From 02f3c946f445d68f7a16f51ede587318a1d164b4 Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Wed, 21 Aug 2019 17:52:00 -0700
Subject: [PATCH] Compressing rematerialization

Adds a new kind of rematerialization that compresses a node into a compact
form and uncompresses it back at a later program point.

PiperOrigin-RevId: 264734096
---
 .../xla/service/hlo_rematerialization.cc      | 558 ++++++++++++++----
 .../xla/service/hlo_rematerialization.h       |  26 +-
 .../xla/service/hlo_rematerialization_test.cc | 137 ++++-
 3 files changed, 596 insertions(+), 125 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index d362317495e..aa723797da1 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -100,6 +100,17 @@ bool CanBeRematerialized(
 using BufferId = int64;
 using BufferIdList = absl::InlinedVector<BufferId, 3>;
 
+struct RematStrategy {
+  enum {
+    // Recompute the node at a later program point.
+    kRecompute,
+    // Change the layout into a compact form and uncompress it back at a
+    // later program point.
+    kCompress,
+  } kind;
+  Shape compact_shape;
+};
+
 // We wrap HloInstruction* with an Item that holds auxiliary
 // per-instruction state.
 struct Item {
@@ -117,6 +128,10 @@ struct Item {
   // The buffers defined by this instruction.
   BufferIdList buffers_defined;
 
+  // The output buffers of this instruction. This is used to track the
+  // outputs of GTE instructions (which use a buffer without defining one).
+  BufferIdList buffers_output;
+
   // The buffers used by this instruction.
   BufferIdList buffers_used;
 
@@ -251,6 +266,34 @@ class InstructionList {
     return InsertBefore(to_insert, min_position_item);
   }
 
+  void InsertAfterInstructions(Item* to_insert,
+                               absl::Span<Item* const> after_instructions) {
+    VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
+            << " after {"
+            << absl::StrJoin(after_instructions, ", ",
+                             [](string* out, Item* item) {
+                               absl::StrAppend(out, item->instruction->name());
+                             })
+            << "}";
+
+    // Find the max position number of any instruction in
+    // 'after_instructions'.
+    CHECK(!after_instructions.empty());
+    Item* max_position_item = nullptr;
+    for (Item* item : after_instructions) {
+      if (max_position_item == nullptr ||
+          item->position > max_position_item->position) {
+        max_position_item = item;
+      }
+    }
+    if (max_position_item->next == nullptr) {
+      InsertAfter(to_insert, max_position_item);
+    } else {
+      InsertBeforeInstructions(to_insert, {max_position_item->next});
+    }
+  }
+
   void Blacklist(const HloInstruction* inst) {
     GetItem(inst)->blacklisted = true;
   }
@@ -276,6 +319,24 @@ class InstructionList {
     item->position = before->position;
   }
 
+  void InsertAfter(Item* item, Item* after) {
+    VLOG(3) << "InsertAfter: " << item->instruction->name() << " after "
+            << after->instruction->name();
+    // Insert new item into linked list.
+    item->next = after->next;
+    item->prev = after;
+    after->next = item;
+    if (item->next != nullptr) {
+      item->next->prev = item;
+    }
+
+    // Assign the same position number to the newly added instruction as
+    // 'after'. This guarantees monotonicity of the position numbers, but not
+    // uniqueness.
+    item->position = after->position;
+  }
+
   Item* first_;
 
   // Item for each instruction.
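A minimal standalone sketch of the doubly linked list splice introduced above (MiniItem and MiniInsertAfter are illustrative names, not part of the patch). It demonstrates the property the InsertAfter comment relies on: the spliced-in item reuses its predecessor's position number, so positions stay monotonic but are not unique.

    #include <cassert>

    struct MiniItem {
      MiniItem* prev = nullptr;
      MiniItem* next = nullptr;
      int position = 0;
    };

    // Mirrors InsertAfter above: splice 'item' in right after 'after' and
    // reuse the position number of 'after'.
    void MiniInsertAfter(MiniItem* item, MiniItem* after) {
      item->next = after->next;
      item->prev = after;
      after->next = item;
      if (item->next != nullptr) {
        item->next->prev = item;
      }
      item->position = after->position;
    }

    int main() {
      MiniItem a, b, c;
      a.position = 1;
      b.position = 2;
      a.next = &b;
      b.prev = &a;

      MiniInsertAfter(&c, &a);  // The list is now a -> c -> b.
      assert(a.next == &c && c.next == &b && b.prev == &c);
      assert(c.position == a.position);  // Monotonic, but not unique.
      return 0;
    }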
@@ -327,6 +388,7 @@ class MemoryUsageTracker {
   MemoryUsageTracker(
       const HloComputation* computation,
       const HloRematerialization::ShapeSizeFunction& size_function,
+      const HloRematerialization::CompactShapeFunction& compact_shape_function,
       const TuplePointsToAnalysis& points_to_analysis,
       const InstructionList& instruction_list);
 
@@ -338,6 +400,22 @@ class MemoryUsageTracker {
   // EndInstruction memory for dead operand(s) is freed.
   Status BeginInstruction(Item* item);
 
+  int64 RematerializationCost(const HloInstruction* instruction,
+                              int64 memory_reduced, int64 memory_limit_bytes) {
+    // If none of the users of 'instruction' have been placed in the sequence
+    // (as tracked by memory_tracker), then rematerialization of 'instruction'
+    // is a zero-cost move of 'instruction' in the sequence.
+    if (!absl::c_any_of(
+            instruction->users(),
+            [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
+      return 0;
+    }
+
+    CHECK_GT(memory_reduced, 0);
+    // Return the inverse of the benefit of rematerialization.
+    return memory_limit_bytes / memory_reduced;
+  }
+
   // Finishes the placement of the current instruction. This frees any dead
   // operands or dead result of the instruction. This must be called after
   // each call to BeginInstruction.
@@ -347,17 +425,28 @@ class MemoryUsageTracker {
   // if the given instruction is rematerialized.
   int64 MemoryReducedIfRematerialized(Item* item) const;
 
+  // Returns the number of bytes by which the current memory usage will be
+  // reduced if the given instruction is compressed.
+  int64 MemoryReducedIfCompressed(Item* item,
+                                  const Shape& compact_shape) const;
+
   // Returns the number of bytes that the current memory usage will be reduced
   // by if the given sequence of instructions is rematerialized.
   int64 MemoryReducedIfRematerialized(const absl::Span<Item*>& items) const;
 
+  Status AddCompressInstructions(Item* original_item, Item* compressed_item,
+                                 Item* uncompressed_item);
+
   // Adjusts memory usage to account for the rematerialization of
   // original_item for all remaining unplaced uses. The rematerialization
   // is remat_item. This method should be called after the HLO graph has
-  // been transformed (rematerialization instruction created and connected to
-  // uses).
+  // been transformed (rematerialization instruction created and connected
+  // to uses).
   Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
 
+  std::pair<Item*, RematStrategy> PickRematerializationCandidate(
+      const InstructionList& instruction_list, int64 memory_limit_bytes,
+      absl::flat_hash_map<const HloInstruction*, bool>* remat_able);
+
   // Returns whether the given instruction has been placed (BeginInstruction
   // has been called with 'instruction' as the argument).
   bool IsPlaced(const HloInstruction* instruction) const {
@@ -390,6 +479,9 @@ class MemoryUsageTracker {
     // The materialized size of the buffer in bytes.
     const int64 size;
 
+    // Shape of the buffer.
+    Shape shape;
+
     // Whether this buffer is live-out of the computation.
     bool live_out;
 
@@ -412,19 +504,21 @@ class MemoryUsageTracker {
     }
   };
 
+  // Gets the compact shape of a given HLO instruction. An internal cache is
+  // used to avoid computing the shape multiple times.
+  StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);
+
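A back-of-the-envelope restatement of the RematerializationCost heuristic above (standalone sketch; SketchRematCost is an illustrative name, not part of the patch): the returned value is the inverse of the benefit, so a candidate that frees more bytes gets a smaller cost, and a candidate none of whose users are placed is a free reordering.

    #include <cstdint>
    #include <iostream>

    int64_t SketchRematCost(bool any_user_placed, int64_t memory_reduced,
                            int64_t memory_limit_bytes) {
      if (!any_user_placed) {
        return 0;  // A pure move in the sequence: zero cost.
      }
      return memory_limit_bytes / memory_reduced;  // Smaller is better.
    }

    int main() {
      const int64_t limit = 30 * 1024;  // Same order as the limit in the tests.
      // Freeing 16 KiB (cost 1) beats freeing 1 KiB (cost 30) when candidates
      // are compared in PickRematerializationCandidate.
      std::cout << SketchRematCost(true, 16 * 1024, limit) << "\n";  // 1
      std::cout << SketchRematCost(true, 1 * 1024, limit) << "\n";   // 30
      return 0;
    }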
  // Creates a Buffer representing the given logical buffer. The buffer is added
   // to buffers_ and a reference is returned.
   Buffer& CreateBufferFromLogicalBuffer(
       const LogicalBuffer* logical_buffer,
-      const TuplePointsToAnalysis& points_to_analysis,
-      const HloRematerialization::ShapeSizeFunction& size_function,
-      bool live_out) {
+      const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
     bool has_indirect_uses = false;
     ItemList users = GetUsers(instruction_list_, logical_buffer,
                               points_to_analysis, &has_indirect_uses);
     return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
-                     size_function(logical_buffer->shape()), std::move(users),
-                     live_out, has_indirect_uses);
+                     logical_buffer->shape(), std::move(users), live_out,
+                     has_indirect_uses);
   }
 
   // Create a new buffer representing a rematerialization of given buffer for
@@ -438,7 +532,7 @@ class MemoryUsageTracker {
     for (Item* use : rematerialized_uses) {
       CHECK(!use->placed) << use->instruction->name();
     }
-    return NewBuffer(remat_item, original_buffer.size,
+    return NewBuffer(remat_item, original_buffer.shape,
                      std::move(rematerialized_uses), /*live_out=*/false,
                      /*has_indirect_uses=*/false);
   }
@@ -449,7 +543,8 @@ class MemoryUsageTracker {
   // different computation.
   int64 AllocatedSize(BufferId buffer_id) const {
     const Buffer& buffer = buffers_.at(buffer_id);
-    HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
+    HloInstruction* inst = buffer.defining_instruction->instruction;
+    HloOpcode def_opcode = inst->opcode();
     if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
       return 0;
     } else {
@@ -482,12 +577,12 @@ class MemoryUsageTracker {
   }
 
   // Create a new buffer, add it to buffers_, and return a reference.
-  Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
-                    bool live_out, bool has_indirect_uses) {
+  Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
+                    ItemList&& users, bool live_out, bool has_indirect_uses) {
     int buffer_id = buffers_.size();
-    buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
-                              has_indirect_uses, users,
-                              static_cast<int64>(users.size())});
+    buffers_.push_back(Buffer{
+        buffer_id, defining_instruction, size_function_(shape), shape,
+        live_out, has_indirect_uses, users, static_cast<int64>(users.size())});
     return buffers_.back();
   }
 
@@ -498,6 +593,16 @@ class MemoryUsageTracker {
   // (BeginInstruction/EndInstruction calls).
   const InstructionList& instruction_list_;
 
+  // Returns the size in bytes of a buffer of the given shape.
+  const HloRematerialization::ShapeSizeFunction& size_function_;
+
+  // Converts a shape into its compact form; returns the same shape if it is
+  // already considered compact.
+  const HloRematerialization::CompactShapeFunction& compact_shape_function_;
+
+  // Caches the known compact shape for each instruction.
+  absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;
+
   // Memory usage at the currently placed instruction.
  int64 memory_usage_ = 0;
@@ -512,9 +617,13 @@
 MemoryUsageTracker::MemoryUsageTracker(
     const HloComputation* computation,
     const HloRematerialization::ShapeSizeFunction& size_function,
+    const HloRematerialization::CompactShapeFunction& compact_shape_function,
    const TuplePointsToAnalysis& points_to_analysis,
     const InstructionList& instruction_list)
-    : computation_(computation), instruction_list_(instruction_list) {
+    : computation_(computation),
+      instruction_list_(instruction_list),
+      size_function_(size_function),
+      compact_shape_function_(compact_shape_function) {
   PointsToSet::BufferSet live_out_set =
       points_to_analysis.GetPointsToSet(computation_->root_instruction())
           .CreateFlattenedSet();
@@ -556,7 +665,7 @@ MemoryUsageTracker::MemoryUsageTracker(
       }
     } else {
       buffer = &CreateBufferFromLogicalBuffer(
-          logical_buffer, points_to_analysis, size_function,
+          logical_buffer, points_to_analysis,
           ContainsKey(live_out_set, logical_buffer));
       item->buffers_defined.push_back(buffer->id);
       for (Item* user : buffer->users) {
@@ -566,6 +675,14 @@ MemoryUsageTracker::MemoryUsageTracker(
       logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
     }
+
+    // Trace the output buffers of each instruction so that we can properly
+    // track the outputs of GTE instructions, which use buffers but do not
+    // define any themselves.
+    for (const LogicalBuffer* logical_buffer :
+         points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
+      item->buffers_output.push_back(
+          logical_buffer_to_buffer_id[logical_buffer]);
+    }
   }
   XLA_VLOG_LINES(10, ToString());
   DCHECK(Check());
@@ -637,6 +754,29 @@ Status MemoryUsageTracker::EndInstruction() {
   return Status::OK();
 }
 
+int64 MemoryUsageTracker::MemoryReducedIfCompressed(
+    Item* item, const Shape& compact_shape) const {
+  CHECK_NE(in_progress_item_, nullptr);
+  if (!item->placed || item == in_progress_item_) {
+    return 0;
+  }
+
+  int64 memory_reduced = 0;
+
+  // We only compress a single output buffer at a time.
+  CHECK_EQ(item->buffers_output.size(), 1);
+  BufferId buffer_id = item->buffers_output[0];
+  if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
+    const Buffer& buffer = buffers_.at(buffer_id);
+    memory_reduced += buffer.size;
+
+    int64 compact_shape_size = size_function_(compact_shape);
+    // Account for the compact buffer that stays live after the instruction.
+    memory_reduced -= compact_shape_size;
+  }
+  return memory_reduced;
+}
+
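The net saving computed by MemoryReducedIfCompressed is just size(original shape) minus size(compact shape) under the active size function. A standalone worked check (PaddedBytesF32 is an illustrative stand-in for the padded size function used in the unit tests below, not part of the patch):

    #include <cstdint>
    #include <iostream>

    // Stand-in for a size function that rounds the most-minor dimension of an
    // f32 array up to a multiple of 64, as the unit tests below do.
    int64_t PaddedBytesF32(int64_t major_dim, int64_t minor_dim) {
      const int64_t padded_minor = ((minor_dim + 63) / 64) * 64;
      return 4 * major_dim * padded_minor;  // 4 bytes per f32 element.
    }

    int main() {
      const int64_t original = PaddedBytesF32(64, 2);  // f32[64,2]{1,0} -> 16384
      const int64_t compact = PaddedBytesF32(2, 64);   // f32[2,64]{1,0} -> 512
      std::cout << "memory_reduced = " << (original - compact) << "\n";  // 15872
      return 0;
    }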
 int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
   CHECK_NE(in_progress_item_, nullptr);
   if (!item->placed || item == in_progress_item_) {
@@ -736,6 +876,56 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
   return memory_reduced;
 }
 
+Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
+                                                   Item* compressed_item,
+                                                   Item* uncompressed_item) {
+  // Original buffer is now dead.
+  memory_usage_ -= size_function_(original_item->instruction->shape());
+  // Compressed buffer is now alive.
+  memory_usage_ += size_function_(compressed_item->instruction->shape());
+
+  ItemList placed_users;
+  ItemList unplaced_users;
+  CHECK_EQ(original_item->buffers_output.size(), 1);
+  BufferId original_buffer_id = original_item->buffers_output[0];
+  Buffer& original_buffer = buffers_.at(original_buffer_id);
+  for (Item* user : original_buffer.users) {
+    if (user->placed) {
+      CHECK(IsFinished(user)) << user->instruction->name();
+      placed_users.push_back(user);
+    } else {
+      unplaced_users.push_back(user);
+    }
+  }
+  original_buffer.users = std::move(placed_users);
+  original_buffer.unfinished_user_count = 0;
+  original_buffer.users.push_back(compressed_item);
+  Buffer& compressed_buffer =
+      NewBuffer(compressed_item, compressed_item->instruction->shape(),
+                {uncompressed_item}, /*live_out=*/false,
+                /*has_indirect_uses=*/false);
+  compressed_item->buffers_used = original_item->buffers_output;
+  compressed_item->buffers_output = {compressed_buffer.id};
+  compressed_item->buffers_defined.push_back(compressed_buffer.id);
+
+  Buffer& uncompressed_buffer =
+      NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
+                std::move(unplaced_users), /*live_out=*/false,
+                /*has_indirect_uses=*/false);
+
+  uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
+  uncompressed_item->buffers_output = {uncompressed_buffer.id};
+  uncompressed_item->buffers_defined = {uncompressed_buffer.id};
+
+  for (Item* user : uncompressed_buffer.users) {
+    BufferIdList& buffers_used = user->buffers_used;
+    std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
+                 uncompressed_buffer.id);
+  }
+
+  return Status::OK();
+}
+
 Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
                                                         Item* remat_item) {
   VLOG(3) << "AddRematerializedInstruction: original_instruction = "
@@ -831,6 +1021,17 @@ string MemoryUsageTracker::ToString() const {
   return output;
 }
 
+StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
+  auto it = compact_shape_.find(hlo);
+  if (it != compact_shape_.end()) {
+    return it->second;
+  }
+  const Shape& original_shape = hlo->shape();
+  TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
+  compact_shape_[hlo] = min_shape;
+  return min_shape;
+}
+
 bool MemoryUsageTracker::Check() const {
   auto elements_are_unique = [](const BufferIdList& vec) {
     return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
@@ -917,12 +1118,15 @@ int64 RematerializationCost(const HloInstruction* instruction,
 // candidate which reduces memory use at the program point of the current
 // instruction as indicated by memory_tracker. A null Item* is returned if no
 // candidate can be found.
-Item* PickRematerializationCandidate(
-    const MemoryUsageTracker& memory_tracker,
+std::pair<Item*, RematStrategy>
+MemoryUsageTracker::PickRematerializationCandidate(
     const InstructionList& instruction_list, int64 memory_limit_bytes,
     absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
   Item* best_item = nullptr;
   int64 best_cost = 0;
+  RematStrategy best_strategy;
+
+  VLOG(5) << "Picking candidate";
 
   // TODO(b/35244891): This is currently quadratic in the number of HLO
   // instructions.
@@ -947,44 +1151,215 @@ Item* PickRematerializationCandidate(
     if (!CanBeRematerialized(candidate, remat_able)) {
       VLOG(5) << "candidate " << candidate->name()
               << " not viable: is not rematerializable";
+      continue;
     }
 
-    // If any of the candidate's control successor has been placed, we need to
-    // skip this candidate. Otherwise we will violate control dependency.
-    bool control_successor_placed =
-        std::any_of(candidate->control_successors().begin(),
-                    candidate->control_successors().end(),
-                    [&memory_tracker](const HloInstruction* inst) {
-                      return memory_tracker.IsPlaced(inst);
-                    });
+    if (item->buffers_output.size() == 1) {
+      // Only consider compressing single-output instructions.
+      const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
+
+      if (item->placed && item != in_progress_item_ &&
+          !output_buffer.live_out) {
+        const Shape& original_shape = item->instruction->shape();
+        if (original_shape.IsArray()) {
+          Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
+          const int64 memory_reduced =
+              MemoryReducedIfCompressed(item, compact_shape);
+          if (memory_reduced > 0) {
+            const int64 cost = memory_limit_bytes / memory_reduced;
+            if (best_item == nullptr || cost < best_cost) {
+              VLOG(3) << "candidate " << candidate->name() << " ("
+                      << candidate->ToShortString() << ")"
+                      << " now best when compressed into "
+                      << compact_shape.ToString(true);
+              RematStrategy strategy;
+              strategy.kind = RematStrategy::kCompress;
+              best_strategy = strategy;
+              best_strategy.compact_shape = compact_shape;
+              best_item = item;
+              best_cost = cost;
+            }
+          }
+        }
+      }
+    }
+
+    // If any of the candidate's control successors has been placed, we need
+    // to skip this candidate. Otherwise we would violate a control
+    // dependency.
+    bool control_successor_placed = std::any_of(
+        candidate->control_successors().begin(),
+        candidate->control_successors().end(),
+        [this](const HloInstruction* inst) { return IsPlaced(inst); });
 
     if (control_successor_placed) {
       continue;
     }
 
-    const int64 memory_reduced =
-        memory_tracker.MemoryReducedIfRematerialized(item);
+    const int64 memory_reduced = MemoryReducedIfRematerialized(item);
 
-    if (memory_reduced <= 0) {
-      VLOG(5) << "candidate " << candidate->name()
-              << " memory reduced = " << memory_reduced << " <= 0";
-      continue;
-    }
+    if (memory_reduced > 0) {
+      const int cost =
+          RematerializationCost(candidate, memory_reduced, memory_limit_bytes);
 
-    const int cost = RematerializationCost(candidate, memory_tracker,
-                                           memory_reduced, memory_limit_bytes);
+      VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
+              << memory_reduced << ", cost per byte " << cost;
 
-    VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
-            << memory_reduced << ", cost per byte " << cost;
-
-    if (best_item == nullptr || cost < best_cost) {
-      VLOG(5) << "candidate " << candidate->name() << " now best";
-      best_item = item;
-      best_cost = cost;
+      if (best_item == nullptr || cost < best_cost) {
+        VLOG(5) << "candidate " << candidate->name() << " now best";
+        best_strategy.kind = RematStrategy::kRecompute;
+        best_item = item;
+        best_cost = cost;
+      }
     }
   }
 
-  return best_item;
+  return {best_item, best_strategy};
+}
+
+StatusOr<int64> RematerializeInstruction(
+    MemoryUsageTracker* memory_tracker, Item* best_item,
+    absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
+    InstructionList* instruction_list) {
+  HloInstruction* best = best_item->instruction;
+  VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
+          << HumanReadableNumBytes(
+                 memory_tracker->MemoryReducedIfRematerialized(best_item))
+          << ")";
+
+  int64 net_instructions_added = 0;
+
+  HloComputation* computation = best->parent();
+
+  HloInstruction* remat =
+      computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
+
+  // Add control dependencies to the new operation.
+  for (auto successor : best->control_successors()) {
+    TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
+  }
+  for (auto predecessor : best->control_predecessors()) {
+    TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
+  }
+
+  Item* remat_item = instruction_list->CreateItem(remat);
+
+  // Replace each remaining use of 'best' with the rematerialization.
+  std::vector<HloInstruction*> best_users_copy = best->users();
+  for (HloInstruction* user : best_users_copy) {
+    if (!memory_tracker->IsPlaced(user)) {
+      VLOG(2) << "  Replacing use of " << best->name() << " in " << user->name()
+              << " with " << remat->name();
+      TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
+    }
+  }
+
+  // Account for the rematerialization in the memory tracker.
+  TF_RETURN_IF_ERROR(
+      memory_tracker->AddRematerializedInstruction(best_item, remat_item));
+
+  // Insert the rematerialized instruction right before the earliest unplaced
+  // use of the instruction *and* the earliest unplaced last use of any
+  // operands of remat. Unplaced uses of the remat's operands are included
+  // because we don't want to extend the live range of remat's operands as
+  // this could increase memory usage.
+  ItemList place_before;
+  for (auto user : remat->users()) {
+    place_before.push_back(instruction_list->GetItem(user));
+  }
+  for (auto* operand : remat->operands()) {
+    for (auto* operand_user : operand->users()) {
+      if (operand_user != remat) {
+        Item* operand_user_item = instruction_list->GetItem(operand_user);
+        if (!operand_user_item->placed) {
+          place_before.push_back(operand_user_item);
+        }
+      }
+    }
+  }
+  // Insert the rematerialized instruction before any of its control
+  // successors to preserve ordering with regard to control dependencies.
+  for (auto successor : remat->control_successors()) {
+    Item* successor_item = instruction_list->GetItem(successor);
+    // Assert to make sure we never remat an operation with a control
+    // successor that is already placed.
+    CHECK(!successor_item->placed) << successor_item->instruction->name();
+    place_before.push_back(successor_item);
+  }
+  instruction_list->InsertBeforeInstructions(remat_item, place_before);
+
+  // If the rematerialized instruction is dead then rematerialization is
+  // essentially a move. Don't delete the instruction now because we don't
+  // want duplicate HloInstruction* values during the course of the
+  // transformation because we keep maps with HloInstruction* values as
+  // keys.
+  if (best->users().empty()) {
+    VLOG(2) << best->name() << " is now dead";
+    if (ContainsKey(*remat_move_instructions, best)) {
+      // Previously, 'best' was a rematerialization which killed the
+      // instruction it was a copy of. Now 'remat' is a rematerialization
+      // of 'best' and kills 'best'. Stop rematerializing this instruction
+      // to avoid an infinite loop.
+      instruction_list->Blacklist(remat);
+    }
+    remat_move_instructions->insert(remat);
+  } else {
+    net_instructions_added++;
+  }
+  return net_instructions_added;
+}
+
+StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
+                                    Item* best_item, const Shape& compact_shape,
+                                    InstructionList* instruction_list) {
+  HloInstruction* best = best_item->instruction;
+  VLOG(5) << "Compressing instruction " << best->name() << " (saving "
+          << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
+                 best_item, compact_shape))
+          << ") to " << compact_shape.ToString(true);
+
+  HloComputation* computation = best->parent();
+
+  HloInstruction* compressed = computation->AddInstruction(
+      HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));
+
+  HloInstruction* uncompressed = computation->AddInstruction(
+      HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));
+
+  Item* compressed_item = instruction_list->CreateItem(compressed);
+  compressed_item->placed = true;
+
+  Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
+
+  // Replace each remaining use of 'best' with the uncompressed copy.
+  std::vector<HloInstruction*> best_users_copy = best->users();
+  for (HloInstruction* user : best_users_copy) {
+    if (!memory_tracker->IsPlaced(user)) {
+      VLOG(5) << "  Replacing use of " << best->name() << " in " << user->name()
+              << " with " << uncompressed->name();
+      TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
+    }
+  }
+
+  // Account for the compression in the memory tracker.
+  TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
+      best_item, compressed_item, uncompressed_item));
+
+  // Insert the uncompressed instruction right before its earliest unplaced
+  // use, and the compressed instruction right after the original one, so
+  // that only the compact buffer is live between the two copies.
+  ItemList place_before;
+  for (auto user : uncompressed->users()) {
+    place_before.push_back(instruction_list->GetItem(user));
+  }
+
+  instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
+
+  instruction_list->InsertAfterInstructions(compressed_item, {best_item});
+
+  return 2;
+}
+
 }  // namespace
 
@@ -993,7 +1368,8 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
     const HloComputation* computation,
     const HloInstructionSequence& order) const {
   InstructionList instruction_list(order);
-  MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
+  MemoryUsageTracker tracker(computation, size_function_,
+                             compact_shape_function_, *points_to_analysis_,
                              instruction_list);
   int64 peak_memory = tracker.memory_usage();
   for (auto* item = instruction_list.first(); item != nullptr;
@@ -1037,6 +1413,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
   InstructionList instruction_list(schedule->sequence(computation));
   MemoryUsageTracker memory_tracker(computation, size_function_,
+                                    compact_shape_function_,
                                     *points_to_analysis_, instruction_list);
   bool changed = false;
@@ -1086,8 +1463,11 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
                                     callee_usage)
             << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
 
-    Item* best_item = PickRematerializationCandidate(
-        memory_tracker, instruction_list, memory_limit_bytes, &remat_able);
+    Item* best_item;
+    RematStrategy best_strategy;
+    std::tie(best_item, best_strategy) =
+        memory_tracker.PickRematerializationCandidate(
+            instruction_list, memory_limit_bytes, &remat_able);
 
     if (best_item == nullptr) {
      VLOG(3) << "Unable to find rematerialization candidate at program "
@@ -1106,81 +1486,19 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
     changed = true;
     remat_count++;
 
-    HloInstruction* remat =
-        computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
-
-    // Add control dependencies to the new operation.
-    for (auto successor : best->control_successors()) {
-      TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
-    }
-    for (auto predecessor : best->control_predecessors()) {
-      TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
-    }
-
-    Item* remat_item = instruction_list.CreateItem(remat);
-
-    // Replace each remaining use of 'best' with the rematerialization.
-    std::vector<HloInstruction*> best_users_copy = best->users();
-    for (HloInstruction* user : best_users_copy) {
-      if (!memory_tracker.IsPlaced(user)) {
-        VLOG(2) << "  Replacing use of " << best->name() << " in "
-                << user->name() << " with " << remat->name();
-        TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
-      }
-    }
-
-    // Account for the rematerialization in the memory tracker.
-    TF_RETURN_IF_ERROR(
-        memory_tracker.AddRematerializedInstruction(best_item, remat_item));
-
-    // Insert rematerialized instruction right before the earliest unplaced
-    // use of the instruction *and* the earliest unplaced last use of any
-    // operands of remat. Unplaced uses of the remat's operands are included
-    // because we don't want to extend the live range of remat's operands as
-    // this could increase memory usage.
-    ItemList place_before;
-    for (auto user : remat->users()) {
-      place_before.push_back(instruction_list.GetItem(user));
-    }
-    for (auto* operand : remat->operands()) {
-      for (auto* operand_user : operand->users()) {
-        if (operand_user != remat) {
-          Item* operand_user_item = instruction_list.GetItem(operand_user);
-          if (!operand_user_item->placed) {
-            place_before.push_back(operand_user_item);
-          }
-        }
-      }
-    }
-    // Insert rematerialized instruction before any of its successors to
-    // preserve ordering regarding control dependency.
-    for (auto successor : remat->control_successors()) {
-      Item* successor_item = instruction_list.GetItem(successor);
-      // Assert to make sure we never remat an operation with control
-      // successor already placed.
-      CHECK(!successor_item->placed) << successor_item->instruction->name();
-      place_before.push_back(successor_item);
-    }
-    instruction_list.InsertBeforeInstructions(remat_item, place_before);
-
-    // If the rematerialized instruction is dead then rematerialization is
-    // essentially a move. Don't delete the instruction now because we don't
-    // want duplicate HloInstruction* values during the course of the
-    // transformation because we keep maps with HloInstruction* values as
-    // keys.
-    if (best->users().empty()) {
-      VLOG(2) << best->name() << " is now dead";
-      if (ContainsKey(remat_move_instructions, best)) {
-        // Previously, 'best' was a rematerialization which killed the
-        // instruction it was a copying of. Now 'remat' is a rematerialization
-        // of 'best' and kills 'best'. Stop rematerializing this instruction
-        // to avoid an infinite loop.
-        instruction_list.Blacklist(remat);
-      }
-      remat_move_instructions.insert(remat);
+    int64 added_instruction = 0;
+    if (best_strategy.kind == RematStrategy::kCompress) {
+      TF_ASSIGN_OR_RETURN(added_instruction,
+                          CompressInstruction(&memory_tracker, best_item,
+                                              best_strategy.compact_shape,
+                                              &instruction_list));
     } else {
-      net_instructions_added++;
+      TF_ASSIGN_OR_RETURN(added_instruction,
+                          RematerializeInstruction(&memory_tracker, best_item,
+                                                   &remat_move_instructions,
+                                                   &instruction_list));
     }
+    net_instructions_added += added_instruction;
 
     VLOG(1) << "memory_usage after rematerialization = "
             << HumanReadableNumBytes(memory_tracker.memory_usage());
@@ -1357,7 +1675,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
     sizes_->after_bytes = current_peak_memory;
   }
 
-  XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
+  XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
 
   if (current_peak_memory > memory_limit_bytes_) {
     LOG(WARNING) << absl::StrFormat(
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index ebbc2dd6b5c..9ab34b4862d 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -24,6 +24,8 @@
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/statusor.h"
 
 namespace xla {
 
@@ -38,6 +40,8 @@ class HloRematerialization : public HloModulePass {
  public:
   using ShapeSizeFunction = std::function<int64(const Shape&)>;
 
+  using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;
+
   // Helper struct that communicates the before / after sizes for the
   // rematerialization process.
  struct RematerializationSizes {
@@ -45,6 +49,8 @@ class HloRematerialization : public HloModulePass {
     int64 before_bytes;
     int64 after_bytes;
   };
 
+  static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
+
   // Constructor parameters:
   //
   // size_function: Function which returns the size in bytes of the top-level
@@ -57,12 +63,20 @@ class HloRematerialization : public HloModulePass {
   // sizes: Pointer to data structure which records the peak memory usage of
   //   the HLO module before/after rematerialization. Values are set during
   //   Run(). Can be nullptr.
-  HloRematerialization(const ShapeSizeFunction& size_function,
-                       int64 memory_limit_bytes, RematerializationSizes* sizes)
+  //
+  // compact_shape_function: Function which returns the compact form of a
+  //   shape. If nullptr is provided, a default identity function is used.
+  explicit HloRematerialization(
+      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
+      RematerializationSizes* sizes,
+      CompactShapeFunction compact_shape_function = nullptr)
       : size_function_(size_function),
         memory_limit_bytes_(memory_limit_bytes),
-        sizes_(sizes) {}
-  ~HloRematerialization() {}
+        sizes_(sizes),
+        compact_shape_function_(compact_shape_function == nullptr
                                    ? DefaultCompactShapeFunction
+                                    : std::move(compact_shape_function)) {}
+  ~HloRematerialization() override = default;
 
   absl::string_view name() const override { return "rematerialization"; }
 
@@ -109,6 +123,10 @@ class HloRematerialization : public HloModulePass {
   // module before/after rematerialization
   RematerializationSizes* sizes_;
 
+  // Converts a shape into its compact form; returns the same shape if it is
+  // already considered compact.
+  const CompactShapeFunction compact_shape_function_;
+
   // Call graph of the hlo_module.
   std::unique_ptr<CallGraph> call_graph_;
 
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 987177e40b8..dabd9d20f64 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -27,7 +27,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
 namespace xla {
@@ -534,6 +533,142 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
 INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest,
                          ::testing::Values(true, false));
 
+class CompressingRematerializationTest : public RematerializationTestBase {
+ protected:
+  // A special shape-size function, which rounds the most-minor dimension up
+  // to a multiple of 64.
+  static int64 ShapeSizePadMinorTo64(const Shape& shape) {
+    if (shape.IsTuple()) {
+      // Size of a tuple is 4 bytes.
+      return 4;
+    }
+    Shape descending_shape =
+        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape);
+    int64 size =
+        ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type());
+    for (int64 i = 0; i < descending_shape.rank(); ++i) {
+      int64 dim = shape.dimensions(i);
+      if (i == descending_shape.rank() - 1) {
+        dim = RoundUpToNearest<int64>(dim, 64);
+      }
+      size *= dim;
+    }
+    return size;
+  }
+
+  // Swaps the two most-minor dimensions if the second-minor dimension is
+  // bigger than the most-minor dimension.
+  static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
+    Shape result = shape;
+    Layout layout = result.layout();
+    int64 most_minor_index = layout.minor_to_major()[0];
+    int64 second_minor_index = layout.minor_to_major()[1];
+    int64 most_minor = result.dimensions(most_minor_index);
+    int64 second_minor = result.dimensions(second_minor_index);
+    if (most_minor < second_minor) {
+      result.set_dimensions(most_minor_index, second_minor);
+      result.set_dimensions(second_minor_index, most_minor);
+    }
+    return result;
+  }
+
+  StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
+                                         HloModule* module) {
+    TF_EXPECT_OK(verifier().Run(module).status());
+    HloRematerialization remat(ShapeSizePadMinorTo64, memory_limit_bytes,
+                               /*sizes=*/nullptr, ChooseCompactLayoutForShape);
+    return remat.Run(module);
+  }
+};
+
+// Test compressing rematerialization of a single instruction.
+TEST_F(CompressingRematerializationTest, SingleRemat) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+%add_float {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %entry {
+  %param.0 = f32[] parameter(0)
+  %constant = f32[] constant(0)
+  %broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
+  %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
+  %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/30 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+  HloInstruction* broadcast =
+      module->entry_computation()->GetInstructionWithName("broadcast.0");
+  HloInstruction* reduce =
+      module->entry_computation()->GetInstructionWithName("reduce.1");
+  EXPECT_THAT(reduce,
+              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
+}
+
+TEST_F(CompressingRematerializationTest, AllUsersUseSameCopy) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+%add_float {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %entry {
+  %param.0 = f32[] parameter(0)
+  %constant = f32[] constant(0)
+  %broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
+  %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
+  %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %reduce.1 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %reduce.2 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
+  %reduce.3 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %add.2 = f32[] add(f32[] %reduce.2, f32[] %reduce.3)
+  ROOT %tuple = (f32[], f32[]) tuple (f32[] add, f32[] add.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module,
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/30 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+
+  HloInstruction* broadcast =
+      module->entry_computation()->GetInstructionWithName("broadcast.0");
+
+  // Both reduces reuse the same copy instruction.
+  HloInstruction* reduce_2 =
+      module->entry_computation()->GetInstructionWithName("reduce.2");
+
+  HloInstruction* reduce_3 =
+      module->entry_computation()->GetInstructionWithName("reduce.3");
+
+  EXPECT_THAT(reduce_2,
+              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
+
+  EXPECT_THAT(reduce_3,
+              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
+}
+
 }  // namespace
 }  // namespace xla
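A standalone worked check of why the 30 KiB limit in these tests forces compression (illustrative; the constants are derived from the shapes above, and the liveness narrative is a sketch of the pass's behavior rather than a quote of it): under ShapeSizePadMinorTo64, f32[64,2]{1,0} pads out to 64 x 64 floats, so broadcast.0 and negate together exceed the limit until broadcast.0 is squeezed into its compact layout.

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t f32_bytes = 4;
      const int64_t padded = 64 * 64 * f32_bytes;  // f32[64,2]{1,0} -> 16 KiB
      const int64_t compact = 2 * 64 * f32_bytes;  // f32[2,64]{1,0} -> 512 B
      const int64_t limit = 30 * 1024;

      // While negate is being computed, broadcast.0 and negate are both live:
      // 32 KiB, which is over the 30 KiB limit.
      assert(padded + padded > limit);

      // With broadcast.0 compressed, only the 512-byte compact copy stays
      // live alongside negate until the uncompress copy feeds reduce.1.
      assert(padded + compact <= limit);
      return 0;
    }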