Refactors parts of AlternateMemoryBestFitHeap to make a few private member variables and methods protected so that they can be reused by subclasses.

PiperOrigin-RevId: 336200966
Change-Id: Ia2faeac959e7ace3ddf1a6a7ee434714f90259b2
This commit is contained in:
A. Unique TensorFlower 2020-10-08 17:47:43 -07:00 committed by TensorFlower Gardener
parent 2d83c3230f
commit 21c2dcd821
2 changed files with 48 additions and 25 deletions

View File

@ -709,12 +709,13 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues(
}
void AlternateMemoryBestFitHeap::FindAliases(
std::vector<AllocationValue>* allocation_values) const {
std::vector<AllocationValue>* allocation_values,
bool skip_values_with_no_uses) const {
absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
values_by_defining_inst;
for (AllocationValue& value : *allocation_values) {
// Skip the value if it doesn't have any uses.
if (value.uses().empty()) {
if (value.uses().empty() && skip_values_with_no_uses) {
continue;
}
CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
@ -1157,7 +1158,7 @@ void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals(
for (const auto& colocated_interval : colocated_intervals) {
CreateAllocationValues(*colocated_interval, allocation_values);
}
FindAliases(&allocation_values);
FindAliases(&allocation_values, /*skip_values_with_no_uses=*/true);
}
AlternateMemoryBestFitHeap::Result

View File

@ -728,6 +728,16 @@ class MemorySpaceAssignment {
// All the positions where this use aliases with. The aliased positions
// must get the same allocation.
std::vector<HloPosition> aliases;
bool operator==(const Use& other) const {
return hlo_use == other.hlo_use && time == other.time &&
aliases == other.aliases;
}
template <typename H>
friend H AbslHashValue(H h, const Use& s) {
return H::combine(std::move(h), s.hlo_use, s.time, s.aliases);
}
};
AllocationValue(const HloValue* value, const HloPosition& position,
@ -823,6 +833,8 @@ class MemorySpaceAssignment {
AllocationSequence allocations_;
HloModule* module() { return module_; }
private:
// Process calls Process methods of the allocations after the allocations have
// been finalized.
@ -949,6 +961,38 @@ class AlternateMemoryBestFitHeap
HeapSimulator::Result<HloValue> Finish() override;
protected:
// Given a buffer interval, returns the colocated intervals. Unlike the
// similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
// returns the colocated intervals sorted by scheduled time.
std::vector<const BufferInterval*> GetSortedColocatedIntervals(
const BufferInterval& interval) const;
// Given a BufferInterval, creates AllocationValue objects and corresponding
// AllocationSequences and appends them into allocation_sequence_list_.
void CreateAllocationValues(
const BufferInterval& buffer_interval,
std::vector<AllocationValue>& allocation_values) const;
// Given colocated intervals, populates allocation_values with the
// corresponding AllocationValue objects.
void CreateAllocationValuesFromColocatedIntervals(
absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
colocated_intervals,
std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values);
// Go through all the uses in the AllocationValues and find the aliasing
// positions.
void FindAliases(std::vector<AllocationValue>* allocation_values,
bool skip_values_with_no_uses) const;
MemorySpaceAssignment::AllocationSequence* allocations() {
return allocations_;
}
const MemorySpaceAssignment::Options& options() { return options_; }
const HloAliasAnalysis& alias_analysis() { return alias_analysis_; }
const HloLiveRange& hlo_live_range() { return hlo_live_range_; }
private:
// We inherit AllocationBlock struct to attach the Allocation information to
// make importing repacked offsets easier.
@ -1096,18 +1140,6 @@ class AlternateMemoryBestFitHeap
bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
const HloUse& use) const;
// Given a BufferInterval, creates AllocationValue objects and corresponding
// AllocationSequences and appends them into allocation_sequence_list_.
void CreateAllocationValues(
const BufferInterval& buffer_interval,
std::vector<AllocationValue>& allocation_values) const;
// Given colocated intervals, populates allocation_values with the
// corresponding AllocationValue objects.
void CreateAllocationValuesFromColocatedIntervals(
absl::Span<const BufferInterval* const> colocated_intervals,
std::vector<AllocationValue>& allocation_values);
// Finds allocations for allocation values generated from colocated intervals.
// All of the allocation values have a must-alias relationship with each
// other. Returns either kSuccess if all of the sites could be placed in the
@ -1115,10 +1147,6 @@ class AlternateMemoryBestFitHeap
Result AllocateAllocationValues(
absl::Span<AllocationValue> allocation_values);
// Go through all the uses in the AllocationValues and find the aliasing
// positions.
void FindAliases(std::vector<AllocationValue>* allocation_values) const;
// Finds an allocation for an allocation request for a segment (see the
// documentation for AllocationRequest above how a segment is defined).
//
@ -1194,12 +1222,6 @@ class AlternateMemoryBestFitHeap
bool AreIntervalsReservedInAlternateMemory(
absl::Span<const BufferInterval* const> colocated_intervals) const;
// Given a buffer interval, returns the colocated intervals. Unlike the
// similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
// returns the colocated intervals sorted by scheduled time.
std::vector<const BufferInterval*> GetSortedColocatedIntervals(
const BufferInterval& interval) const;
// Since the allocations are recorded to the AllocationSequence, we don't
// maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
// to avoid unnecessarily adding the chunk to the chunk map.