From 02f3c946f445d68f7a16f51ede587318a1d164b4 Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Wed, 21 Aug 2019 17:52:00 -0700
Subject: [PATCH] Compressing rematerialization

Adds a new kind of rematerialization that compresses the node into a compact form, uncompresses it back at a later program point.

PiperOrigin-RevId: 264734096
---
 .../xla/service/hlo_rematerialization.cc      | 558 ++++++++++++++----
 .../xla/service/hlo_rematerialization.h       |  26 +-
 .../xla/service/hlo_rematerialization_test.cc | 137 ++++-
 3 files changed, 596 insertions(+), 125 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index d362317495e..aa723797da1 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -100,6 +100,17 @@ bool CanBeRematerialized(
 using BufferId = int64;
 using BufferIdList = absl::InlinedVector<BufferId, 3>;
 
+struct RematStrategy {
+  enum {
+    // Recompute the node at a later program point.
+    kRecompute,
+    // Change the layout into a compact form and uncompress it back at a later
+    // program point.
+    kCompress,
+  } kind;
+  Shape compact_shape;
+};
+
 // We wrap HloInstruction* with an Item that holds auxiliary
 // per-instruction state.
 struct Item {
@@ -117,6 +128,10 @@ struct Item {
   // The buffers defined by this instruction.
   BufferIdList buffers_defined;
 
+  // Output buffers of this instruction. This is used to track outputs by GTE
+  // instructions (where the instruction doesn't define a buffer).
+  BufferIdList buffers_output;
+
   // The buffers used by this instruction.
   BufferIdList buffers_used;
 
@@ -251,6 +266,34 @@ class InstructionList {
     return InsertBefore(to_insert, min_position_item);
   }
 
+  void InsertAfterInstructions(Item* to_insert,
+                               absl::Span<Item* const> after_instructions) {
+    VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
+            << " after {"
+            << absl::StrJoin(after_instructions, ", ",
+                             [](string* out, Item* item) {
+                               absl::StrAppend(out, item->instruction->name());
+                             })
+            << "}";
+
+    // Find the max position number of any instruction in
+    // 'after_instructions'.
+    CHECK(!after_instructions.empty());
+    Item* max_position_item = nullptr;
+    for (Item* item : after_instructions) {
+      if (max_position_item == nullptr ||
+          item->position > max_position_item->position) {
+        max_position_item = item;
+      }
+    }
+    if (max_position_item->next == nullptr) {
+      InsertAfter(to_insert, max_position_item);
+
+    } else {
+      InsertBeforeInstructions(to_insert, {max_position_item->next});
+    }
+  }
+
   void Blacklist(const HloInstruction* inst) {
     GetItem(inst)->blacklisted = true;
   }
@@ -276,6 +319,24 @@ class InstructionList {
     item->position = before->position;
   }
 
+  void InsertAfter(Item* item, Item* after) {
+    VLOG(3) << "InsertAfter: " << item->instruction->name() << " after "
+            << after->instruction->name();
+    // Insert new item into linked list.
+    item->next = after->next;
+    item->prev = after;
+
+    after->next = item;
+    if (item->next != nullptr) {
+      item->next->prev = item;
+    }
+
+    // Assign the same position number to the newly added instruction as
+    // 'after'. This guarantees monotonicity of the position numbers, but not
+    // uniqueness.
+    item->position = after->position;
+  }
+
   Item* first_;
 
   // Item for each instruction.
@@ -327,6 +388,7 @@ class MemoryUsageTracker {
   MemoryUsageTracker(
       const HloComputation* computation,
       const HloRematerialization::ShapeSizeFunction& size_function,
+      const HloRematerialization::CompactShapeFunction& compact_shape_function,
       const TuplePointsToAnalysis& points_to_analysis,
       const InstructionList& instruction_list);
 
@@ -338,6 +400,22 @@ class MemoryUsageTracker {
   // EndInstruction memory for dead operand(s) is freed.
   Status BeginInstruction(Item* item);
 
+  int64 RematerializationCost(const HloInstruction* instruction,
+                              int64 memory_reduced, int64 memory_limit_bytes) {
+    // If none of the users of 'instruction' have been placed in the sequence
+    // (as tracked by memory_tracker), then rematerialization of 'instruction'
+    // is a zero-cost move of 'instruction' in the sequence.
+    if (!absl::c_any_of(
+            instruction->users(),
+            [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
+      return 0;
+    }
+
+    CHECK_GT(memory_reduced, 0);
+    // Return the inverse of the benefit of rematerialization.
+    return memory_limit_bytes / memory_reduced;
+  }
+
   // Finishes the placement of the current instruction. This frees any dead
   // operands or dead result of the instruction. This must be called after
   // each call to BeginInstruction.
@@ -347,17 +425,28 @@ class MemoryUsageTracker {
   // if the given instruction is rematerialized.
   int64 MemoryReducedIfRematerialized(Item* item) const;
 
+  // Returns the number of bytes that the current memory usage will be reduced
+  // if the given instruction is compressed.
+  int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;
+
   // Returns the number of bytes that the current memory usage will be reduced
   // by if the given sequence of instructions is rematerialized.
   int64 MemoryReducedIfRematerialized(const absl::Span<Item*>& items) const;
 
+  Status AddCompressInstructions(Item* original_item, Item* compressed_item,
+                                 Item* uncompressed_item);
+
   // Adjusts memory usage to account for the rematerialization of
   // original_item for all remaining unplaced uses. The rematerialization
   // is remat_item. This method should be called after the HLO graph has
-  // been transformed (rematerialization instruction created and connected to
-  // uses).
+  // been transformed (rematerialization instruction created and connected
+  // to uses).
   Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
 
+  std::pair<Item*, RematStrategy> PickRematerializationCandidate(
+      const InstructionList& instruction_list, int64 memory_limit_bytes,
+      absl::flat_hash_map<const HloInstruction*, bool>* remat_able);
+
   // Returns whether the given instruction has been placed (BeginInstruction
   // has been called with 'instruction' as the argument).
   bool IsPlaced(const HloInstruction* instruction) const {
@@ -390,6 +479,9 @@ class MemoryUsageTracker {
     // The materialized size of the buffer in bytes.
     const int64 size;
 
+    // Shape of the buffer.
+    Shape shape;
+
     // Whether this buffer is live-out of the computation.
     bool live_out;
 
@@ -412,19 +504,21 @@ class MemoryUsageTracker {
     }
   };
 
+  // Get the compact shape of the given HLO instruction. An internal cache is used
+  // to avoid computing the shape multiple times.
+  StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);
+
   // Creates a Buffer representing the given logical buffer. The buffer is added
   // to buffers_ and a reference is returned.
   Buffer& CreateBufferFromLogicalBuffer(
       const LogicalBuffer* logical_buffer,
-      const TuplePointsToAnalysis& points_to_analysis,
-      const HloRematerialization::ShapeSizeFunction& size_function,
-      bool live_out) {
+      const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
     bool has_indirect_uses = false;
     ItemList users = GetUsers(instruction_list_, logical_buffer,
                               points_to_analysis, &has_indirect_uses);
     return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
-                     size_function(logical_buffer->shape()), std::move(users),
-                     live_out, has_indirect_uses);
+                     logical_buffer->shape(), std::move(users), live_out,
+                     has_indirect_uses);
   }
 
   // Create a new buffer representing a rematerialization of given buffer for
@@ -438,7 +532,7 @@ class MemoryUsageTracker {
     for (Item* use : rematerialized_uses) {
       CHECK(!use->placed) << use->instruction->name();
     }
-    return NewBuffer(remat_item, original_buffer.size,
+    return NewBuffer(remat_item, original_buffer.shape,
                      std::move(rematerialized_uses), /*live_out=*/false,
                      /*has_indirect_uses=*/false);
   }
@@ -449,7 +543,8 @@ class MemoryUsageTracker {
   // different computation.
   int64 AllocatedSize(BufferId buffer_id) const {
     const Buffer& buffer = buffers_.at(buffer_id);
-    HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
+    HloInstruction* inst = buffer.defining_instruction->instruction;
+    HloOpcode def_opcode = inst->opcode();
     if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
       return 0;
     } else {
@@ -482,12 +577,12 @@ class MemoryUsageTracker {
   }
 
   // Create a new buffer, add it to buffers_, and return a reference.
-  Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
-                    bool live_out, bool has_indirect_uses) {
+  Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
+                    ItemList&& users, bool live_out, bool has_indirect_uses) {
     int buffer_id = buffers_.size();
-    buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
-                              has_indirect_uses, users,
-                              static_cast<int64>(users.size())});
+    buffers_.push_back(Buffer{
+        buffer_id, defining_instruction, size_function_(shape), shape, live_out,
+        has_indirect_uses, users, static_cast<int64>(users.size())});
     return buffers_.back();
   }
 
@@ -498,6 +593,16 @@ class MemoryUsageTracker {
   // (BeginInstruction/EndInstruction calls).
   const InstructionList& instruction_list_;
 
+  // Size function returns the bytes of a given buffer.
+  const HloRematerialization::ShapeSizeFunction& size_function_;
+
+  // Converts a shape into compact form, returns the same shape if a shape is
+  // already considered compact.
+  const HloRematerialization::CompactShapeFunction& compact_shape_function_;
+
+  // A map that caches existing known compact shape for each instruction.
+  absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;
+
   // Memory usage at the currently placed instruction.
   int64 memory_usage_ = 0;
 
@@ -512,9 +617,13 @@ class MemoryUsageTracker {
 MemoryUsageTracker::MemoryUsageTracker(
     const HloComputation* computation,
     const HloRematerialization::ShapeSizeFunction& size_function,
+    const HloRematerialization::CompactShapeFunction& compact_shape_function,
     const TuplePointsToAnalysis& points_to_analysis,
     const InstructionList& instruction_list)
-    : computation_(computation), instruction_list_(instruction_list) {
+    : computation_(computation),
+      instruction_list_(instruction_list),
+      size_function_(size_function),
+      compact_shape_function_(compact_shape_function) {
   PointsToSet::BufferSet live_out_set =
       points_to_analysis.GetPointsToSet(computation_->root_instruction())
           .CreateFlattenedSet();
@@ -556,7 +665,7 @@ MemoryUsageTracker::MemoryUsageTracker(
         }
       } else {
         buffer = &CreateBufferFromLogicalBuffer(
-            logical_buffer, points_to_analysis, size_function,
+            logical_buffer, points_to_analysis,
             ContainsKey(live_out_set, logical_buffer));
         item->buffers_defined.push_back(buffer->id);
         for (Item* user : buffer->users) {
@@ -566,6 +675,14 @@ MemoryUsageTracker::MemoryUsageTracker(
 
       logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
     }
+
+    // Trace the output of each instruction. This is so that we can properly
+    // track which outputs GTEs have.
+    for (const LogicalBuffer* logical_buffer :
+         points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
+      item->buffers_output.push_back(
+          logical_buffer_to_buffer_id[logical_buffer]);
+    }
   }
   XLA_VLOG_LINES(10, ToString());
   DCHECK(Check());
@@ -637,6 +754,29 @@ Status MemoryUsageTracker::EndInstruction() {
   return Status::OK();
 }
 
+int64 MemoryUsageTracker::MemoryReducedIfCompressed(
+    Item* item, const Shape& compact_shape) const {
+  CHECK_NE(in_progress_item_, nullptr);
+  if (!item->placed || item == in_progress_item_) {
+    return 0;
+  }
+
+  int64 memory_reduced = 0;
+
+  // We only compress a single piece of an output at one time.
+  CHECK_EQ(item->buffers_output.size(), 1);
+  BufferId buffer_id = item->buffers_output[0];
+  if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
+    const Buffer& buffer = buffers_.at(buffer_id);
+    memory_reduced += buffer.size;
+
+    int64 compact_shape_size = size_function_(compact_shape);
+    // Account for the buffer that is compressed after the instruction.
+    memory_reduced -= compact_shape_size;
+  }
+  return memory_reduced;
+}
+
 int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
   CHECK_NE(in_progress_item_, nullptr);
   if (!item->placed || item == in_progress_item_) {
@@ -736,6 +876,56 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
   return memory_reduced;
 }
 
+Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
+                                                   Item* compressed_item,
+                                                   Item* uncompressed_item) {
+  // Original buffer is now dead.
+  memory_usage_ -= size_function_(original_item->instruction->shape());
+  // Compressed buffer is now alive.
+  memory_usage_ += size_function_(compressed_item->instruction->shape());
+
+  ItemList placed_users;
+  ItemList unplaced_users;
+  CHECK_EQ(original_item->buffers_output.size(), 1);
+  BufferId original_buffer_id = original_item->buffers_output[0];
+  Buffer& original_buffer = buffers_.at(original_buffer_id);
+  for (Item* user : original_buffer.users) {
+    if (user->placed) {
+      CHECK(IsFinished(user)) << user->instruction->name();
+      placed_users.push_back(user);
+    } else {
+      unplaced_users.push_back(user);
+    }
+  }
+  original_buffer.users = std::move(placed_users);
+  original_buffer.unfinished_user_count = 0;
+  original_buffer.users.push_back(compressed_item);
+  Buffer& compressed_buffer =
+      NewBuffer(compressed_item, compressed_item->instruction->shape(),
+                {uncompressed_item}, /*live_out=*/false,
+                /*has_indirect_uses=*/false);
+  compressed_item->buffers_used = original_item->buffers_output;
+  compressed_item->buffers_output = {compressed_buffer.id};
+  compressed_item->buffers_defined.push_back(compressed_buffer.id);
+
+  Buffer& uncompressed_buffer =
+      NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
+                std::move(unplaced_users), /*live_out=*/false,
+                /*has_indirect_uses=*/false);
+
+  uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
+  uncompressed_item->buffers_output = {uncompressed_buffer.id};
+  uncompressed_item->buffers_defined = {uncompressed_buffer.id};
+
+  for (Item* user : uncompressed_buffer.users) {
+    BufferIdList& buffers_used = user->buffers_used;
+    std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
+                 uncompressed_buffer.id);
+  }
+
+  return Status::OK();
+}
+
 Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
                                                         Item* remat_item) {
   VLOG(3) << "AddRematerializedInstruction: original_instruction = "
@@ -831,6 +1021,17 @@ string MemoryUsageTracker::ToString() const {
   return output;
 }
 
+StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
+  auto it = compact_shape_.find(hlo);
+  if (it != compact_shape_.end()) {
+    return it->second;
+  }
+  const Shape& original_shape = hlo->shape();
+  TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
+  compact_shape_[hlo] = min_shape;
+  return min_shape;
+}
+
 bool MemoryUsageTracker::Check() const {
   auto elements_are_unique = [](const BufferIdList& vec) {
     return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
@@ -917,12 +1118,15 @@ int64 RematerializationCost(const HloInstruction* instruction,
 // candidate which reduce memory use at the program point of the current
 // instruction as indicated by memory_tracker. nullptr is returned if no
 // candidate can be found.
-Item* PickRematerializationCandidate(
-    const MemoryUsageTracker& memory_tracker,
+std::pair<Item*, RematStrategy>
+MemoryUsageTracker::PickRematerializationCandidate(
     const InstructionList& instruction_list, int64 memory_limit_bytes,
     absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
   Item* best_item = nullptr;
   int64 best_cost = 0;
+  RematStrategy best_strategy;
+
+  VLOG(5) << "Picking candidate";
 
   // TODO(b/35244891): This is currently quadratic in the number of HLO
   // instructions.
@@ -947,44 +1151,215 @@ Item* PickRematerializationCandidate(
     if (!CanBeRematerialized(candidate, remat_able)) {
       VLOG(5) << "candidate " << candidate->name()
               << " not viable: is not rematerializable";
+
       continue;
     }
 
-    // If any of the candidate's control successor has been placed, we need to
-    // skip this candidate. Otherwise we will violate control dependency.
-    bool control_successor_placed =
-        std::any_of(candidate->control_successors().begin(),
-                    candidate->control_successors().end(),
-                    [&memory_tracker](const HloInstruction* inst) {
-                      return memory_tracker.IsPlaced(inst);
-                    });
+    if (item->buffers_output.size() == 1) {
+      // Only consider compressing single output instruction.
+      const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
+
+      if (item->placed && item != in_progress_item_ &&
+          !output_buffer.live_out) {
+        const Shape& original_shape = item->instruction->shape();
+        if (original_shape.IsArray()) {
+          Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
+          const int64 memory_reduced =
+              MemoryReducedIfCompressed(item, compact_shape);
+          if (memory_reduced > 0) {
+            const int64 cost = memory_limit_bytes / memory_reduced;
+            if (best_item == nullptr || cost < best_cost) {
+              VLOG(3) << "candidate " << candidate->name() << "("
+                      << candidate->ToShortString() << ")"
+                      << " now best when compressed into "
+                      << compact_shape.ToString(true);
+              RematStrategy strategy;
+              strategy.kind = RematStrategy::kCompress;
+              best_strategy = strategy;
+              best_strategy.compact_shape = compact_shape;
+              best_item = item;
+              best_cost = cost;
+            }
+          }
+        }
+      }
+    }
+
+    // If any of the candidate's control successor has been placed, we need
+    // to skip this candidate. Otherwise we will violate control dependency.
+    bool control_successor_placed = std::any_of(
+        candidate->control_successors().begin(),
+        candidate->control_successors().end(),
+        [this](const HloInstruction* inst) { return IsPlaced(inst); });
 
     if (control_successor_placed) {
       continue;
     }
 
-    const int64 memory_reduced =
-        memory_tracker.MemoryReducedIfRematerialized(item);
+    const int64 memory_reduced = MemoryReducedIfRematerialized(item);
 
-    if (memory_reduced <= 0) {
-      VLOG(5) << "candidate " << candidate->name()
-              << " memory reduced = " << memory_reduced << " <=  0";
-      continue;
-    }
+    if (memory_reduced > 0) {
+      const int cost =
+          RematerializationCost(candidate, memory_reduced, memory_limit_bytes);
 
-    const int cost = RematerializationCost(candidate, memory_tracker,
-                                           memory_reduced, memory_limit_bytes);
+      VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
+              << memory_reduced << ", cost per byte " << cost;
 
-    VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
-            << memory_reduced << ", cost per byte " << cost;
-
-    if (best_item == nullptr || cost < best_cost) {
-      VLOG(5) << "candidate " << candidate->name() << " now best";
-      best_item = item;
-      best_cost = cost;
+      if (best_item == nullptr || cost < best_cost) {
+        VLOG(5) << "candidate " << candidate->name() << " now best";
+        best_strategy.kind = RematStrategy::kRecompute;
+        best_item = item;
+        best_cost = cost;
+      }
     }
   }
-  return best_item;
+  return {best_item, best_strategy};
+}
+
+StatusOr<int64> RematerializeInstruction(
+    MemoryUsageTracker* memory_tracker, Item* best_item,
+    absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
+    InstructionList* instruction_list) {
+  HloInstruction* best = best_item->instruction;
+  VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
+          << HumanReadableNumBytes(
+                 memory_tracker->MemoryReducedIfRematerialized(best_item))
+          << ")";
+
+  int64 net_instructions_added = 0;
+
+  HloComputation* computation = best->parent();
+
+  HloInstruction* remat =
+      computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
+
+  // Add control dependencies to the new operation.
+  for (auto successor : best->control_successors()) {
+    TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
+  }
+  for (auto predecessor : best->control_predecessors()) {
+    TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
+  }
+
+  Item* remat_item = instruction_list->CreateItem(remat);
+
+  // Replace each remaining use of 'best' with the rematerialization.
+  std::vector<HloInstruction*> best_users_copy = best->users();
+  for (HloInstruction* user : best_users_copy) {
+    if (!memory_tracker->IsPlaced(user)) {
+      VLOG(2) << "  Replacing use of " << best->name() << " in " << user->name()
+              << " with " << remat->name();
+      TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
+    }
+  }
+
+  // Account for the rematerialization in the memory tracker.
+  TF_RETURN_IF_ERROR(
+      memory_tracker->AddRematerializedInstruction(best_item, remat_item));
+
+  // Insert rematerialized instruction right before the earliest unplaced
+  // use of the instruction *and* the earliest unplaced last use of any
+  // operands of remat. Unplaced uses of the remat's operands are included
+  // because we don't want to extend the live range of remat's operands as
+  // this could increase memory usage.
+  ItemList place_before;
+  for (auto user : remat->users()) {
+    place_before.push_back(instruction_list->GetItem(user));
+  }
+  for (auto* operand : remat->operands()) {
+    for (auto* operand_user : operand->users()) {
+      if (operand_user != remat) {
+        Item* operand_user_item = instruction_list->GetItem(operand_user);
+        if (!operand_user_item->placed) {
+          place_before.push_back(operand_user_item);
+        }
+      }
+    }
+  }
+  // Insert rematerialized instruction before any of its successors to
+  // preserve ordering regarding control dependency.
+  for (auto successor : remat->control_successors()) {
+    Item* successor_item = instruction_list->GetItem(successor);
+    // Assert to make sure we never remat an operation with control
+    // successor already placed.
+    CHECK(!successor_item->placed) << successor_item->instruction->name();
+    place_before.push_back(successor_item);
+  }
+  instruction_list->InsertBeforeInstructions(remat_item, place_before);
+
+  // If the rematerialized instruction is dead then rematerialization is
+  // essentially a move. Don't delete the instruction now because we don't
+  // want duplicate HloInstruction* values during the course of the
+  // transformation because we keep maps with HloInstruction* values as
+  // keys.
+  if (best->users().empty()) {
+    VLOG(2) << best->name() << " is now dead";
+    if (ContainsKey(*remat_move_instructions, best)) {
+      // Previously, 'best' was a rematerialization which killed the
+      // instruction it was a copying of. Now 'remat' is a rematerialization
+      // of 'best' and kills 'best'. Stop rematerializing this instruction
+      // to avoid an infinite loop.
+      instruction_list->Blacklist(remat);
+    }
+    remat_move_instructions->insert(remat);
+
+  } else {
+    net_instructions_added++;
+  }
+  return net_instructions_added;
+}
+
+StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
+                                    Item* best_item, const Shape& compact_shape,
+                                    InstructionList* instruction_list) {
+  HloInstruction* best = best_item->instruction;
+  VLOG(5) << "Transposing instruction " << best->name() << " (saving "
+          << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
+                 best_item, compact_shape))
+          << ") to" << compact_shape.ToString(true);
+
+  HloComputation* computation = best->parent();
+
+  HloInstruction* compressed = computation->AddInstruction(
+      HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));
+
+  HloInstruction* uncompressed = computation->AddInstruction(
+      HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));
+
+  Item* compressed_item = instruction_list->CreateItem(compressed);
+  compressed_item->placed = true;
+
+  Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
+
+  // Replace each remaining use of 'best' with the uncompressed.
+  std::vector<HloInstruction*> best_users_copy = best->users();
+  for (HloInstruction* user : best_users_copy) {
+    if (!memory_tracker->IsPlaced(user)) {
+      VLOG(5) << "  Replacing use of " << best->name() << " in " << user->name()
+              << " with " << uncompressed->name();
+      TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
+    }
+  }
+
+  // Account for the rematerialization in the memory tracker.
+  TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
+      best_item, compressed_item, uncompressed_item));
+
+  // Insert rematerialized instruction right before the earliest unplaced
+  // use of the instruction *and* the earliest unplaced last use of any
+  // operands of remat. Unplaced uses of the remat's operands are included
+  // because we don't want to extend the live range of remat's operands as
+  // this could increase memory usage.
+  ItemList place_before;
+  for (auto user : uncompressed->users()) {
+    place_before.push_back(instruction_list->GetItem(user));
+  }
+
+  instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
+
+  instruction_list->InsertAfterInstructions(compressed_item, {best_item});
+
+  return 2;
 }
 
 }  // namespace
@@ -993,7 +1368,8 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
     const HloComputation* computation,
     const HloInstructionSequence& order) const {
   InstructionList instruction_list(order);
-  MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
+  MemoryUsageTracker tracker(computation, size_function_,
+                             compact_shape_function_, *points_to_analysis_,
                              instruction_list);
   int64 peak_memory = tracker.memory_usage();
   for (auto* item = instruction_list.first(); item != nullptr;
@@ -1037,6 +1413,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
 
   InstructionList instruction_list(schedule->sequence(computation));
   MemoryUsageTracker memory_tracker(computation, size_function_,
+                                    compact_shape_function_,
                                     *points_to_analysis_, instruction_list);
   bool changed = false;
 
@@ -1086,8 +1463,11 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
                                        callee_usage)
               << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
 
-      Item* best_item = PickRematerializationCandidate(
-          memory_tracker, instruction_list, memory_limit_bytes, &remat_able);
+      Item* best_item;
+      RematStrategy best_strategy;
+      std::tie(best_item, best_strategy) =
+          memory_tracker.PickRematerializationCandidate(
+              instruction_list, memory_limit_bytes, &remat_able);
 
       if (best_item == nullptr) {
         VLOG(3) << "Unable to find rematerialization candidate at program "
@@ -1106,81 +1486,19 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
       changed = true;
       remat_count++;
 
-      HloInstruction* remat =
-          computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
-
-      // Add control dependencies to the new operation.
-      for (auto successor : best->control_successors()) {
-        TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
-      }
-      for (auto predecessor : best->control_predecessors()) {
-        TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
-      }
-
-      Item* remat_item = instruction_list.CreateItem(remat);
-
-      // Replace each remaining use of 'best' with the rematerialization.
-      std::vector<HloInstruction*> best_users_copy = best->users();
-      for (HloInstruction* user : best_users_copy) {
-        if (!memory_tracker.IsPlaced(user)) {
-          VLOG(2) << "  Replacing use of " << best->name() << " in "
-                  << user->name() << " with " << remat->name();
-          TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
-        }
-      }
-
-      // Account for the rematerialization in the memory tracker.
-      TF_RETURN_IF_ERROR(
-          memory_tracker.AddRematerializedInstruction(best_item, remat_item));
-
-      // Insert rematerialized instruction right before the earliest unplaced
-      // use of the instruction *and* the earliest unplaced last use of any
-      // operands of remat. Unplaced uses of the remat's operands are included
-      // because we don't want to extend the live range of remat's operands as
-      // this could increase memory usage.
-      ItemList place_before;
-      for (auto user : remat->users()) {
-        place_before.push_back(instruction_list.GetItem(user));
-      }
-      for (auto* operand : remat->operands()) {
-        for (auto* operand_user : operand->users()) {
-          if (operand_user != remat) {
-            Item* operand_user_item = instruction_list.GetItem(operand_user);
-            if (!operand_user_item->placed) {
-              place_before.push_back(operand_user_item);
-            }
-          }
-        }
-      }
-      // Insert rematerialized instruction before any of its successors to
-      // preserve ordering regarding control dependency.
-      for (auto successor : remat->control_successors()) {
-        Item* successor_item = instruction_list.GetItem(successor);
-        // Assert to make sure we never remat an operation with control
-        // successor already placed.
-        CHECK(!successor_item->placed) << successor_item->instruction->name();
-        place_before.push_back(successor_item);
-      }
-      instruction_list.InsertBeforeInstructions(remat_item, place_before);
-
-      // If the rematerialized instruction is dead then rematerialization is
-      // essentially a move. Don't delete the instruction now because we don't
-      // want duplicate HloInstruction* values during the course of the
-      // transformation because we keep maps with HloInstruction* values as
-      // keys.
-      if (best->users().empty()) {
-        VLOG(2) << best->name() << " is now dead";
-        if (ContainsKey(remat_move_instructions, best)) {
-          // Previously, 'best' was a rematerialization which killed the
-          // instruction it was a copying of. Now 'remat' is a rematerialization
-          // of 'best' and kills 'best'. Stop rematerializing this instruction
-          // to avoid an infinite loop.
-          instruction_list.Blacklist(remat);
-        }
-        remat_move_instructions.insert(remat);
+      int64 added_instruction = 0;
+      if (best_strategy.kind == RematStrategy::kCompress) {
+        TF_ASSIGN_OR_RETURN(added_instruction,
+                            CompressInstruction(&memory_tracker, best_item,
+                                                best_strategy.compact_shape,
+                                                &instruction_list));
       } else {
-        net_instructions_added++;
+        TF_ASSIGN_OR_RETURN(added_instruction,
+                            RematerializeInstruction(&memory_tracker, best_item,
+                                                     &remat_move_instructions,
+                                                     &instruction_list));
       }
+      net_instructions_added += added_instruction;
 
       VLOG(1) << "memory_usage after rematerialization = "
               << HumanReadableNumBytes(memory_tracker.memory_usage());
@@ -1357,7 +1675,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
     sizes_->after_bytes = current_peak_memory;
   }
 
-  XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
+  XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
 
   if (current_peak_memory > memory_limit_bytes_) {
     LOG(WARNING) << absl::StrFormat(
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index ebbc2dd6b5c..9ab34b4862d 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -24,6 +24,8 @@
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/statusor.h"
 
 namespace xla {
 
@@ -38,6 +40,8 @@ class HloRematerialization : public HloModulePass {
  public:
   using ShapeSizeFunction = std::function<int64(const Shape&)>;
 
+  using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;
+
   // Helper struct that communicates the before / after sizes for the
   // rematerialization process.
   struct RematerializationSizes {
@@ -45,6 +49,8 @@ class HloRematerialization : public HloModulePass {
     int64 after_bytes;
   };
 
+  static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
+
   // Constructor parameters:
   //
   //   size_function: Function which returns the size in bytes of the top-level
@@ -57,12 +63,20 @@ class HloRematerialization : public HloModulePass {
   //   sizes: Pointer to data structure which records the peak memory usage of
   //     the HLO module before/after rematerialization. Value are set during
   //     Run(). Can be nullptr.
-  HloRematerialization(const ShapeSizeFunction& size_function,
-                       int64 memory_limit_bytes, RematerializationSizes* sizes)
+  //
+  //   compact_shape_function: Function which returns the compact form of a
+  //   shape. If nullptr is provided, a default identity function is used.
+  explicit HloRematerialization(
+      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
+      RematerializationSizes* sizes,
+      CompactShapeFunction compact_shape_function = nullptr)
       : size_function_(size_function),
         memory_limit_bytes_(memory_limit_bytes),
-        sizes_(sizes) {}
-  ~HloRematerialization() {}
+        sizes_(sizes),
+        compact_shape_function_(compact_shape_function == nullptr
+                                    ? DefaultCompactShapeFunction
+                                    : std::move(compact_shape_function)) {}
+  ~HloRematerialization() override = default;
 
   absl::string_view name() const override { return "rematerialization"; }
 
@@ -109,6 +123,10 @@ class HloRematerialization : public HloModulePass {
   // module before/after rematerialization
   RematerializationSizes* sizes_;
 
+  // Converts a shape into compact form, returns the same shape if a shape is
+  // already considered compact.
+  const CompactShapeFunction compact_shape_function_;
+
   // Call graph of the hlo_module.
   std::unique_ptr<CallGraph> call_graph_;
 
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 987177e40b8..dabd9d20f64 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -27,7 +27,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
 namespace xla {
@@ -534,6 +533,142 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
 INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest,
                          ::testing::Values(true, false));
 
+class CompressingRematerializationTest : public RematerializationTestBase {
+ protected:
+  // A special shape size function, which pads the most minor dimension to 64.
+  static int64 ShapeSizePadMinorTo64(const Shape& shape) {
+    if (shape.IsTuple()) {
+      // Size of a tuple is 4 bytes.
+      return 4;
+    }
+    Shape descending_shape =
+        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape);
+    int64 size =
+        ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type());
+    for (int64 i = 0; i < descending_shape.rank(); ++i) {
+      int64 dim = shape.dimensions(i);
+      if (i == descending_shape.rank() - 1) {
+        dim = RoundUpToNearest<int64>(dim, 64);
+      }
+      size *= dim;
+    }
+    return size;
+  }
+
+  // Swap the two most-minor dimensions if the second-minor dimension is bigger
+  // than the most-minor dimension.
+  static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
+    Shape result = shape;
+    Layout layout = result.layout();
+    int64 most_minor_index = layout.minor_to_major()[0];
+    int64 second_minor_index = layout.minor_to_major()[1];
+    int64 most_minor = result.dimensions(most_minor_index);
+    int64 second_minor = result.dimensions(second_minor_index);
+    if (most_minor < second_minor) {
+      result.set_dimensions(most_minor_index, second_minor);
+      result.set_dimensions(second_minor_index, most_minor);
+    }
+    return result;
+  }
+
+  StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
+                                         HloModule* module) {
+    TF_EXPECT_OK(verifier().Run(module).status());
+    HloRematerialization remat(ShapeSizePadMinorTo64, memory_limit_bytes,
+                               /*sizes=*/nullptr, ChooseCompactLayoutForShape);
+    return remat.Run(module);
+  }
+};
+
+// Test rematerialization of a single instruction.
+TEST_F(CompressingRematerializationTest, SingleRemat) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+%add_float {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %entry {
+  %param.0 = f32[] parameter(0)
+  %constant = f32[] constant(0)
+  %broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
+  %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
+  %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/30 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+  HloInstruction* broadcast =
+      module->entry_computation()->GetInstructionWithName("broadcast.0");
+  HloInstruction* reduce =
+      module->entry_computation()->GetInstructionWithName("reduce.1");
+  EXPECT_THAT(reduce,
+              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
+}
+
+TEST_F(CompressingRematerializationTest, AllUsersUseSameCopy) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+%add_float {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %entry {
+  %param.0 = f32[] parameter(0)
+  %constant = f32[] constant(0)
+  %broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
+  %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
+  %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %reduce.1 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %reduce.2 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
+  %reduce.3 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %add.2 = f32[] add(f32[] %reduce.2, f32[] %reduce.3)
+  ROOT %tuple = (f32[], f32[]) tuple (f32[] add, f32[] add.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/30 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* broadcast =
+      module->entry_computation()->GetInstructionWithName("broadcast.0");
+
+  // Both reduces reuse the same copy instruction.
+  HloInstruction* reduce_2 =
+      module->entry_computation()->GetInstructionWithName("reduce.2");
+
+  HloInstruction* reduce_3 =
+      module->entry_computation()->GetInstructionWithName("reduce.3");
+
+  EXPECT_THAT(reduce_2,
+              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
+
+  EXPECT_THAT(reduce_3,
+              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
+}
+
 }  // namespace
 
 }  // namespace xla