diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 6d4b0e65010..efee06fdbf3 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -709,12 +709,13 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( } void AlternateMemoryBestFitHeap::FindAliases( - std::vector* allocation_values) const { + std::vector* allocation_values, + bool skip_values_with_no_uses) const { absl::flat_hash_map values_by_defining_inst; for (AllocationValue& value : *allocation_values) { // Skip the value if it doesn't have any uses. - if (value.uses().empty()) { + if (value.uses().empty() && skip_values_with_no_uses) { continue; } CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0); @@ -1157,7 +1158,7 @@ void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( for (const auto& colocated_interval : colocated_intervals) { CreateAllocationValues(*colocated_interval, allocation_values); } - FindAliases(&allocation_values); + FindAliases(&allocation_values, /*skip_values_with_no_uses=*/true); } AlternateMemoryBestFitHeap::Result diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 409a44d319d..b1f59fa9c78 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -728,6 +728,16 @@ class MemorySpaceAssignment { // All the positions where this use aliases with. The aliased positions // must get the same allocation. std::vector aliases; + + bool operator==(const Use& other) const { + return hlo_use == other.hlo_use && time == other.time && + aliases == other.aliases; + } + + template + friend H AbslHashValue(H h, const Use& s) { + return H::combine(std::move(h), s.hlo_use, s.time, s.aliases); + } }; AllocationValue(const HloValue* value, const HloPosition& position, @@ -823,6 +833,8 @@ class MemorySpaceAssignment { AllocationSequence allocations_; + HloModule* module() { return module_; } + private: // Process calls Process methods of the allocations after the allocations have // been finalized. @@ -949,6 +961,38 @@ class AlternateMemoryBestFitHeap HeapSimulator::Result Finish() override; + protected: + // Given a buffer interval, returns the colocated intervals. Unlike the + // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it + // returns the colocated intervals sorted by scheduled time. + std::vector GetSortedColocatedIntervals( + const BufferInterval& interval) const; + + // Given a BufferInterval, creates AllocationValue objects and corresponding + // AllocationSequences and appends them into allocation_sequence_list_. + void CreateAllocationValues( + const BufferInterval& buffer_interval, + std::vector& allocation_values) const; + + // Given colocated intervals, populates allocation_values with the + // corresponding AllocationValue objects. + void CreateAllocationValuesFromColocatedIntervals( + absl::Span + colocated_intervals, + std::vector& allocation_values); + + // Go through all the uses in the AllocationValues and find the aliasing + // positions. + void FindAliases(std::vector* allocation_values, + bool skip_values_with_no_uses) const; + + MemorySpaceAssignment::AllocationSequence* allocations() { + return allocations_; + } + const MemorySpaceAssignment::Options& options() { return options_; } + const HloAliasAnalysis& alias_analysis() { return alias_analysis_; } + const HloLiveRange& hlo_live_range() { return hlo_live_range_; } + private: // We inherit AllocationBlock struct to attach the Allocation information to // make importing repacked offsets easier. @@ -1096,18 +1140,6 @@ class AlternateMemoryBestFitHeap bool IsUseAllowedInAlternateMemory(const AllocationValue& value, const HloUse& use) const; - // Given a BufferInterval, creates AllocationValue objects and corresponding - // AllocationSequences and appends them into allocation_sequence_list_. - void CreateAllocationValues( - const BufferInterval& buffer_interval, - std::vector& allocation_values) const; - - // Given colocated intervals, populates allocation_values with the - // corresponding AllocationValue objects. - void CreateAllocationValuesFromColocatedIntervals( - absl::Span colocated_intervals, - std::vector& allocation_values); - // Finds allocations for allocation values generated from colocated intervals. // All of the allocation values have a must-alias relationship with each // other. Returns either kSuccess if all of the sites could be placed in the @@ -1115,10 +1147,6 @@ class AlternateMemoryBestFitHeap Result AllocateAllocationValues( absl::Span allocation_values); - // Go through all the uses in the AllocationValues and find the aliasing - // positions. - void FindAliases(std::vector* allocation_values) const; - // Finds an allocation for an allocation request for a segment (see the // documentation for AllocationRequest above how a segment is defined). // @@ -1194,12 +1222,6 @@ class AlternateMemoryBestFitHeap bool AreIntervalsReservedInAlternateMemory( absl::Span colocated_intervals) const; - // Given a buffer interval, returns the colocated intervals. Unlike the - // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it - // returns the colocated intervals sorted by scheduled time. - std::vector GetSortedColocatedIntervals( - const BufferInterval& interval) const; - // Since the allocations are recorded to the AllocationSequence, we don't // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap // to avoid unnecessarily adding the chunk to the chunk map.