From 21c2dcd82181b9be9c2108b5f6e9619c1fc28496 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 8 Oct 2020 17:47:43 -0700
Subject: [PATCH] Refactors parts of AlternateMemoryBestFitHeap to make a few
 private member variables and methods protected so that they can be reused by
 subclasses.

PiperOrigin-RevId: 336200966
Change-Id: Ia2faeac959e7ace3ddf1a6a7ee434714f90259b2
---
 .../xla/service/memory_space_assignment.cc    |  7 +-
 .../xla/service/memory_space_assignment.h     | 66 ++++++++++++-------
 2 files changed, 48 insertions(+), 25 deletions(-)

diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 6d4b0e65010..efee06fdbf3 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -709,12 +709,13 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues(
 }
 
 void AlternateMemoryBestFitHeap::FindAliases(
-    std::vector<AllocationValue>* allocation_values) const {
+    std::vector<AllocationValue>* allocation_values,
+    bool skip_values_with_no_uses) const {
   absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
       values_by_defining_inst;
   for (AllocationValue& value : *allocation_values) {
     // Skip the value if it doesn't have any uses.
-    if (value.uses().empty()) {
+    if (value.uses().empty() && skip_values_with_no_uses) {
       continue;
     }
     CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
@@ -1157,7 +1158,7 @@ void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals(
   for (const auto& colocated_interval : colocated_intervals) {
     CreateAllocationValues(*colocated_interval, allocation_values);
   }
-  FindAliases(&allocation_values);
+  FindAliases(&allocation_values, /*skip_values_with_no_uses=*/true);
 }
 
 AlternateMemoryBestFitHeap::Result
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 409a44d319d..b1f59fa9c78 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -728,6 +728,16 @@ class MemorySpaceAssignment {
       // All the positions where this use aliases with. The aliased positions
       // must get the same allocation.
       std::vector<HloPosition> aliases;
+
+      bool operator==(const Use& other) const {
+        return hlo_use == other.hlo_use && time == other.time &&
+               aliases == other.aliases;
+      }
+
+      template <typename H>
+      friend H AbslHashValue(H h, const Use& s) {
+        return H::combine(std::move(h), s.hlo_use, s.time, s.aliases);
+      }
     };
 
     AllocationValue(const HloValue* value, const HloPosition& position,
@@ -823,6 +833,8 @@ class MemorySpaceAssignment {
 
   AllocationSequence allocations_;
 
+  HloModule* module() { return module_; }
+
  private:
   // Process calls Process methods of the allocations after the allocations have
   // been finalized.
@@ -949,6 +961,38 @@ class AlternateMemoryBestFitHeap
 
   HeapSimulator::Result<HloValue> Finish() override;
 
+ protected:
+  // Given a buffer interval, returns the colocated intervals. Unlike the
+  // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
+  // returns the colocated intervals sorted by scheduled time.
+  std::vector<const BufferInterval*> GetSortedColocatedIntervals(
+      const BufferInterval& interval) const;
+
+  // Given a BufferInterval, creates AllocationValue objects and corresponding
+  // AllocationSequences and appends them into allocation_sequence_list_.
+  void CreateAllocationValues(
+      const BufferInterval& buffer_interval,
+      std::vector<AllocationValue>& allocation_values) const;
+
+  // Given colocated intervals, populates allocation_values with the
+  // corresponding AllocationValue objects.
+  void CreateAllocationValuesFromColocatedIntervals(
+      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
+          colocated_intervals,
+      std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values);
+
+  // Go through all the uses in the AllocationValues and find the aliasing
+  // positions.
+  void FindAliases(std::vector<AllocationValue>* allocation_values,
+                   bool skip_values_with_no_uses) const;
+
+  MemorySpaceAssignment::AllocationSequence* allocations() {
+    return allocations_;
+  }
+  const MemorySpaceAssignment::Options& options() { return options_; }
+  const HloAliasAnalysis& alias_analysis() { return alias_analysis_; }
+  const HloLiveRange& hlo_live_range() { return hlo_live_range_; }
+
  private:
   // We inherit AllocationBlock struct to attach the Allocation information to
   // make importing repacked offsets easier.
@@ -1096,18 +1140,6 @@ class AlternateMemoryBestFitHeap
   bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
                                      const HloUse& use) const;
 
-  // Given a BufferInterval, creates AllocationValue objects and corresponding
-  // AllocationSequences and appends them into allocation_sequence_list_.
-  void CreateAllocationValues(
-      const BufferInterval& buffer_interval,
-      std::vector<AllocationValue>& allocation_values) const;
-
-  // Given colocated intervals, populates allocation_values with the
-  // corresponding AllocationValue objects.
-  void CreateAllocationValuesFromColocatedIntervals(
-      absl::Span<const BufferInterval* const> colocated_intervals,
-      std::vector<AllocationValue>& allocation_values);
-
   // Finds allocations for allocation values generated from colocated intervals.
   // All of the allocation values have a must-alias relationship with each
   // other. Returns either kSuccess if all of the sites could be placed in the
@@ -1115,10 +1147,6 @@ class AlternateMemoryBestFitHeap
   Result AllocateAllocationValues(
       absl::Span<AllocationValue> allocation_values);
 
-  // Go through all the uses in the AllocationValues and find the aliasing
-  // positions.
-  void FindAliases(std::vector<AllocationValue>* allocation_values) const;
-
   // Finds an allocation for an allocation request for a segment (see the
   // documentation for AllocationRequest above how a segment is defined).
   //
@@ -1194,12 +1222,6 @@ class AlternateMemoryBestFitHeap
   bool AreIntervalsReservedInAlternateMemory(
       absl::Span<const BufferInterval* const> colocated_intervals) const;
 
-  // Given a buffer interval, returns the colocated intervals. Unlike the
-  // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
-  // returns the colocated intervals sorted by scheduled time.
-  std::vector<const BufferInterval*> GetSortedColocatedIntervals(
-      const BufferInterval& interval) const;
-
   // Since the allocations are recorded to the AllocationSequence, we don't
   // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
   // to avoid unnecessarily adding the chunk to the chunk map.