From f7c7fbd40b7a66b2b060a5207e42ec567c51ae89 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Tue, 19 Jan 2021 17:41:31 -0800 Subject: [PATCH] [XLA] Extend hlo_rematerialization pass to support rematerialization of tuple-producing instrs. Allow rematerialization of tuple-producing instructions by extending the process we use to rematerialize bitcasts so that it also handles get-tuple-element'ed buffers that are not nested. This allows rematerializing through tuples as well. PiperOrigin-RevId: 352691189 Change-Id: Ia1a7674c7e32f1c53253cd5b674abce99f87d509 --- tensorflow/compiler/xla/service/BUILD | 1 + .../xla/service/hlo_rematerialization.cc | 274 +++++++++++------- .../xla/service/hlo_rematerialization_test.cc | 223 +++++++++++++- 3 files changed, 378 insertions(+), 120 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 65577e08d77..86e2b347329 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3757,6 +3757,7 @@ cc_library( ":call_graph", ":flatten_call_graph", ":hlo", + ":hlo_casting_utils", ":hlo_dce", ":hlo_memory_scheduler", ":hlo_ordering", diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 2c77243d07f..9fef6a6cfa3 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -32,9 +32,11 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -98,6 +100,14 @@ bool CanBeRematerialized( return rematerializable; } +// Returns true if this instruction relays the buffers it uses to its own +// users and is one of the instruction kinds we support rematerializing +// through. +bool IsSupportedIndirectUser(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kBitcast || + instruction->opcode() == HloOpcode::kGetTupleElement; +} + // Type holding a unique identifier for each Buffer object. 
using BufferId = int64; using BufferIdList = absl::InlinedVector; @@ -162,10 +172,13 @@ struct Item { struct ItemUse { Item* user; int64 operand_number; + absl::optional index; - ItemUse(Item* user, int64 op_num) : user(user), operand_number(op_num) {} + ItemUse(Item* user, int64 op_num, absl::optional index) + : user(user), operand_number(op_num), index(index) {} bool operator==(const ItemUse& other) const { - return user == other.user && operand_number == other.operand_number; + return user == other.user && operand_number == other.operand_number && + index == other.index; } }; @@ -449,16 +462,22 @@ UsesList GetUsers(const InstructionList& instruction_list, continue; } if (buffer_alias.instruction() != logical_buffer->instruction() && - buffer_alias.instruction()->opcode() != HloOpcode::kBitcast) { + !IsSupportedIndirectUser(buffer_alias.instruction())) { *has_indirect_users = true; } // A buffer may be used by the instruction via more than one alias. For // example, a buffer which appears in more than one element of a tuple. Item* user_item = instruction_list.GetItem(user); + absl::optional user_index = + logical_buffer->index().size() != 1 + ? absl::nullopt + : absl::make_optional(logical_buffer->index().back()); for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) { if (!absl::c_linear_search( - users, ItemUse{user_item, static_cast(op_idx)})) { - users.push_back(ItemUse{user_item, static_cast(op_idx)}); + users, + ItemUse{user_item, static_cast(op_idx), user_index})) { + users.push_back( + ItemUse{user_item, static_cast(op_idx), user_index}); } } } @@ -516,10 +535,6 @@ class MemoryUsageTracker { // each call to BeginInstruction. Status EndInstruction(); - // Returns the number of bytes that the current memory usage will be reduced - // if the given instruction is rematerialized. - int64 MemoryReducedIfRematerialized(Item* item) const; - // Returns the number of bytes that the current memory usage will be reduced // if the given instruction is compact. int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const; @@ -538,7 +553,7 @@ class MemoryUsageTracker { // been transformed (rematerialization instruction created and connected // to uses). Status AddRematerializedInstruction(Item* original_item, Item* remat_item, - absl::Span bitcasts); + absl::Span indirect_users); // Selects and returns the best candidate instructions for rematerialization. // A sequence of candidate instructions of length between min_block_size and @@ -612,6 +627,9 @@ class MemoryUsageTracker { // buffer aliasing (eg, tuples). bool has_indirect_uses; + // Position in the tuple this buffer definition lives in. + ShapeIndex index; + // The instructions which use this buffer. 
UsesList users; @@ -639,8 +657,8 @@ class MemoryUsageTracker { UsesList users = GetUsers(instruction_list_, logical_buffer, points_to_analysis, &has_indirect_uses); return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()), - logical_buffer->shape(), std::move(users), live_out, - has_indirect_uses); + logical_buffer->shape(), logical_buffer->index(), + std::move(users), live_out, has_indirect_uses); } // Create a new buffer representing a rematerialization of given buffer for @@ -654,7 +672,7 @@ class MemoryUsageTracker { for (ItemUse& use : rematerialized_uses) { CHECK(!use.user->placed) << use.user->instruction->name(); } - return NewBuffer(remat_item, original_buffer.shape, + return NewBuffer(remat_item, original_buffer.shape, original_buffer.index, std::move(rematerialized_uses), /*live_out=*/false, /*has_indirect_uses=*/false); } @@ -715,7 +733,8 @@ class MemoryUsageTracker { // Create a new buffer, add it to buffers_, and return a reference. Buffer& NewBuffer(Item* defining_instruction, const Shape& shape, - UsesList&& uses, bool live_out, bool has_indirect_uses) { + const ShapeIndex& index, UsesList&& uses, bool live_out, + bool has_indirect_uses) { int buffer_id = buffers_.size(); auto get_num_of_unique_users = [](const UsesList& uses) -> int64 { absl::flat_hash_set users_set; @@ -726,7 +745,7 @@ class MemoryUsageTracker { }; buffers_.push_back(Buffer{ buffer_id, defining_instruction, size_function_(shape), shape, live_out, - has_indirect_uses, uses, get_num_of_unique_users(uses)}); + has_indirect_uses, index, uses, get_num_of_unique_users(uses)}); return buffers_.back(); } @@ -931,51 +950,6 @@ int64 MemoryUsageTracker::MemoryReducedIfCompressed( return memory_reduced; } -int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const { - CHECK_NE(in_progress_item_, nullptr); - if (!item->placed || item == in_progress_item_) { - return 0; - } - - // TODO(b/37687140): Rematerialization can increase peak memory consumption at - // an earlier point in the program if rematerialization extends the live range - // of the operand of the instruction being rematerialized across the live - // range of the value of instruction being rematerialized. Don't rematerialize - // in this case (ie, return 0 here). - - // Compute the amount of memory reduced (if any) by rematerializing - // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer - // be live at this program point, so initially set memory_reduced to the - // size of its defined values. - int64 memory_reduced = 0; - for (BufferId buffer_id : item->buffers_defined) { - // Avoid rematerializing instructions with indirect uses as it is difficult - // to reason about liveness after rematerializing the instruction. - // TODO(b/37714814): Consider rematerializing instructions with indirect - // uses. - if (buffers_.at(buffer_id).has_indirect_uses) { - return 0; - } - - if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) { - memory_reduced += AllocatedSize(buffer_id); - } - } - - // Account for any logical buffers whose live range must be extended across - // this program point. - for (BufferId buffer_id : item->buffers_used) { - if (!IsCurrentlyLive(buffer_id)) { - // This logical buffer is used by 'instruction' but is not live at this - // program point. Rematerializing 'instruction' will extend the buffer's - // live range across this program point. 
- memory_reduced -= AllocatedSize(buffer_id); - } - } - - return memory_reduced; -} - int64 MemoryUsageTracker::MemoryReducedIfRematerialized( absl::Span items) const { CHECK_NE(in_progress_item_, nullptr); @@ -994,17 +968,21 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized( // will no longer be live at this program point, so initially set // memory_reduced to the size of its defined values. for (BufferId buffer_id : item->buffers_defined) { + const Buffer& buffer = buffers_.at(buffer_id); // Avoid rematerializing instructions with indirect uses as it is // difficult to reason about liveness after rematerializing the // instruction. // Avoid rematerializing instructions with live out buffers. + // Avoid rematerializing buffers that are in nested tuples. // TODO(mpurohit): Check why live_out buffers are an issue here. - if (buffers_.at(buffer_id).has_indirect_uses || - buffers_.at(buffer_id).live_out) { + if (buffer.has_indirect_uses || buffer.live_out || + buffer.index.size() > 1) { return 0; } - - if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) { + if (IsInUse(buffer_id)) { + return 0; + } + if (IsCurrentlyLive(buffer_id)) { memory_reduced += AllocatedSize(buffer_id); } } @@ -1053,10 +1031,15 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, } original_buffer.users = std::move(placed_users); original_buffer.unfinished_user_count = 0; - original_buffer.users.push_back(ItemUse{compressed_item, 0}); + original_buffer.users.push_back(ItemUse{compressed_item, 0, absl::nullopt}); + // We are reallocating the vector containing the buffers potentially, + // invalidating the original_buffer reference, so copy the index that we need + // across NewBuffer calls. + ShapeIndex copied_index = original_buffer.index; Buffer& compressed_buffer = NewBuffer(compressed_item, compressed_item->instruction->shape(), - {ItemUse{uncompressed_item, 0}}, /*live_out=*/false, + copied_index, {ItemUse{uncompressed_item, 0, absl::nullopt}}, + /*live_out=*/false, /*has_indirect_uses=*/false); compressed_item->buffers_used = original_item->buffers_output; compressed_item->buffers_output = {compressed_buffer.id}; @@ -1064,7 +1047,7 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, Buffer& uncompressed_buffer = NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(), - std::move(unplaced_users), /*live_out=*/false, + copied_index, std::move(unplaced_users), /*live_out=*/false, /*has_indirect_uses=*/false); uncompressed_item->buffers_used = {compressed_item->buffers_output[0]}; @@ -1081,7 +1064,7 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, } Status MemoryUsageTracker::AddRematerializedInstruction( - Item* original_item, Item* remat_item, absl::Span bitcasts) { + Item* original_item, Item* remat_item, absl::Span indirect_users) { VLOG(3) << "AddRematerializedInstruction: original_instruction = " << original_item->instruction->name() << ", remat_instruction = " << remat_item->instruction->name(); @@ -1108,19 +1091,12 @@ Status MemoryUsageTracker::AddRematerializedInstruction( std::back_inserter(filtered_users), [&](const ItemUse& iu) { return iu.user == original_item; }); for (ItemUse& u : filtered_users) { - buffer.users.push_back(ItemUse{remat_item, u.operand_number}); - } - } - - for (Item* bitcast : bitcasts) { - CHECK_EQ(bitcast->instruction->opcode(), HloOpcode::kBitcast); - for (BufferId buffer_id : bitcast->buffers_used) { - Buffer& buffer = buffers_.at(buffer_id); - buffer.unfinished_user_count++; - 
buffer.users.push_back(ItemUse{bitcast, 0}); + buffer.users.push_back(ItemUse{remat_item, u.operand_number, u.index}); } } + const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(), + indirect_users.end()); // Create a new set of Buffers defined by the new rematerialization // instruction. Update the internal data structures and memory use to account // for them. @@ -1133,7 +1109,19 @@ Status MemoryUsageTracker::AddRematerializedInstruction( if (user.user->placed) { placed_users.push_back(user); } else { - unplaced_users.push_back(user); + // We keep only the indirect users that are in the provided list. + // We consider all the others dead, and we remove any buffer uses they + // performed from the buffer's user list. + if (!IsSupportedIndirectUser(user.user->instruction) || + indirect_users_set.contains(user.user)) { + unplaced_users.push_back(user); + } else { + CHECK(user.user->buffers_defined.empty()) + << "Buffers defined expected to be empty for use passthrough " + "instructions"; + user.user->buffers_output.clear(); + user.user->buffers_used.clear(); + } } } old_buffer.users = std::move(placed_users); @@ -1146,10 +1134,68 @@ Status MemoryUsageTracker::AddRematerializedInstruction( RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users)); remat_item->buffers_defined.push_back(new_buffer.id); + auto update_buffers = [old_buffer_id, new_buffer_id = new_buffer.id]( + BufferIdList& to_update) { + std::replace(to_update.begin(), to_update.end(), old_buffer_id, + new_buffer_id); + }; + // Update users with the id of the new buffer. for (ItemUse& user : new_buffer.users) { - BufferIdList& buffers_used = user.user->buffers_used; - std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, - new_buffer.id); + update_buffers(user.user->buffers_used); + update_buffers(user.user->buffers_output); + } + } + + // Update the indirect users with the id of the new buffers. + for (Item* indirect_user : indirect_users) { + // Source of the buffers that are going to be passed through. + const Item* source_item = + instruction_list_.GetItem(indirect_user->instruction->operand(0)); + switch (indirect_user->instruction->opcode()) { + case HloOpcode::kBitcast: { + // If the source is another indirect user, copy its output buffers + // into the used and output lists of the bitcast, as indirect users + // don't define any buffers. + if (IsSupportedIndirectUser(source_item->instruction)) { + indirect_user->buffers_used = source_item->buffers_output; + indirect_user->buffers_output = source_item->buffers_output; + } else { + // If it's a real instruction producing a buffer, then copy the defined + // buffers into used and output. + indirect_user->buffers_used = source_item->buffers_defined; + indirect_user->buffers_output = source_item->buffers_defined; + } + break; + } + case HloOpcode::kGetTupleElement: { + // GTEs just use the tuple buffer and output the buffer they actually + // extract from the tuple. + const HloGetTupleElementInstruction* gte = + Cast<HloGetTupleElementInstruction>(indirect_user->instruction); + for (BufferId buffer_id : source_item->buffers_defined) { + const Buffer& def_buffer = buffers_.at(buffer_id); + if (def_buffer.index == ShapeIndex{gte->tuple_index()}) { + indirect_user->buffers_output.push_back(buffer_id); + } + // This is the tuple buffer. 
+ if (def_buffer.index.empty()) { + indirect_user->buffers_used.push_back(buffer_id); + } + } + break; + } + default: { + LOG(FATAL) << "Unsupported indirect instruction with opcode " + << HloOpcodeString(indirect_user->instruction->opcode()); + break; + } + } + // Fixup buffer users for the indirect instructions. For GTEs is only the + // tuple buffer, while for bitcast is the buffer they pass through. + for (BufferId buffer_id : indirect_user->buffers_used) { + Buffer& buffer = buffers_.at(buffer_id); + buffer.unfinished_user_count++; + buffer.users.push_back(ItemUse{indirect_user, 0, absl::nullopt}); } } @@ -1414,6 +1460,10 @@ MemoryUsageTracker::PickRematerializationCandidates( // break out of this loop. Move on to the next start_item. break; } + VLOG(5) << "Block contains:"; + for (auto* hlo : block) { + VLOG(5) << hlo->instruction->name(); + } const int64 memory_reduced = MemoryReducedIfRematerialized(block); if (memory_reduced > 0) { @@ -1509,21 +1559,33 @@ StatusOr RematerializeInstructions( Item* remat_item = instruction_list->CreateItem(remat); // Replace each remaining use of 'best' with the rematerialization. - absl::InlinedVector bitcasts; + absl::InlinedVector indirect_users; + absl::flat_hash_map gte_cache; for (auto& user : memory_tracker->GetItemUses(best_item)) { if (!memory_tracker->IsPlaced(user.user->instruction)) { VLOG(2) << " Replacing use of " << best->name() << " in " << user.user->instruction->name() << " with " << remat->name(); const int64 op_idx = user.operand_number; - auto* remat_use = remat; + HloInstruction* remat_use = remat; + if (user.index) { + auto cached_gte = gte_cache.find(*user.index); + if (cached_gte == gte_cache.end()) { + remat_use = computation->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(remat_use->shape(), + *user.index), + remat_use, *user.index)); + indirect_users.push_back(instruction_list->CreateItem(remat_use)); + gte_cache[*user.index] = remat_use; + } else { + remat_use = cached_gte->second; + } + } if (user.user->instruction->operand(op_idx)->shape() != - remat->shape()) { - remat_use = computation->AddInstruction(HloInstruction::CreateUnary( - user.user->instruction->operand(op_idx)->shape(), - HloOpcode::kBitcast, remat)); - bitcasts.push_back(instruction_list->CreateItem(remat_use)); - bitcasts.back()->buffers_output = remat_item->buffers_defined; - bitcasts.back()->buffers_used = remat_item->buffers_defined; + remat_use->shape()) { + remat_use = computation->AddInstruction(HloInstruction::CreateBitcast( + user.user->instruction->operand(op_idx)->shape(), remat_use)); + indirect_users.push_back(instruction_list->CreateItem(remat_use)); } TF_RETURN_IF_ERROR( user.user->instruction->ReplaceOperandWith(op_idx, remat_use)); @@ -1532,7 +1594,7 @@ StatusOr RematerializeInstructions( // Account for the rematerialization in the memory tracker. TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction( - best_item, remat_item, absl::MakeSpan(bitcasts))); + best_item, remat_item, absl::MakeSpan(indirect_users))); // Insert rematerialized instruction right before the earliest unplaced // use of the instruction *and* the earliest unplaced last use of any @@ -1540,14 +1602,18 @@ StatusOr RematerializeInstructions( // because we don't want to extend the live range of remat's operands as // this could increase memory usage. 
ItemList place_before; + const absl::flat_hash_set indirect_users_set(indirect_users.begin(), + indirect_users.end()); for (auto user : remat->users()) { - if (!absl::c_linear_search(bitcasts, instruction_list->GetItem(user))) { + if (!indirect_users_set.contains(instruction_list->GetItem(user))) { place_before.push_back(instruction_list->GetItem(user)); } } - for (auto* bitcast : bitcasts) { - for (auto user : bitcast->instruction->users()) { - place_before.push_back(instruction_list->GetItem(user)); + for (auto* indirect_user : indirect_users) { + for (auto user : indirect_user->instruction->users()) { + if (!indirect_users_set.contains(instruction_list->GetItem(user))) { + place_before.push_back(instruction_list->GetItem(user)); + } } } for (auto* operand : remat->operands()) { @@ -1571,14 +1637,14 @@ StatusOr RematerializeInstructions( } instruction_list->InsertBeforeInstructions(remat_item, place_before); - for (auto* bitcast : bitcasts) { + for (auto* bitcast : indirect_users) { instruction_list->InsertBeforeInstructions(bitcast, place_before); } - // Helper function that looks through bitcasts when determining if there - // is an active user for an HloInstruction. + // Helper function that looks through indirect users when determining if + // there is an active user for an HloInstruction. std::function uses_empty = [&](HloInstruction* i) { for (auto* u : i->users()) { - if (u->opcode() != HloOpcode::kBitcast || !uses_empty(u)) { + if (!IsSupportedIndirectUser(u) || !uses_empty(u)) { return false; } } @@ -1599,12 +1665,12 @@ StatusOr RematerializeInstructions( instruction_list->Denylist(remat); } remat_move_instructions->insert(remat); - net_instructions_added += bitcasts.size(); + net_instructions_added += indirect_users.size(); } else { - net_instructions_added += bitcasts.size() + 1; + net_instructions_added += indirect_users.size() + 1; } - for (auto* bitcast : bitcasts) { - instruction_list->Denylist(bitcast->instruction); + for (auto* indirect_user : indirect_users) { + instruction_list->Denylist(indirect_user->instruction); } } VLOG(1) << "Rematerializing instructions [" diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index e1f7346cf5b..db96a0a34ef 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -470,11 +470,10 @@ TEST_F(HloRematerializationTest, CopyNotRematerialized) { class IndirectUseTest : public HloRematerializationTest, public ::testing::WithParamInterface {}; -TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { - // Test that an rematerializable instruction is not rematerialized if it has - // an indirect use. Test is parameterized on whether the value has an indirect - // use, and the instruction should be rematerialized iff the value has no - // indirect use. Module: +TEST_P(IndirectUseTest, IndirectUseRematerialized) { + // Test that an rematerializable instruction is rematerialized if it has + // indirect use + // Module: // // Entry computation: // F32[] %param = {...} @@ -492,11 +491,10 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // F32[1024] %slice = slice(%concat) // // The value %bcast is live across the call and rematerialization of %bcast - // across that point would reduce peak memory use by 4KB. However, %bcast is - // used indirectly in the %negate so rematerialization should not happen. + // across that point would reduce peak memory use by 4KB. 
// - // This test is parameterized on whether the broadcast has an indirect use or - // not. The indirect use is controlled by the index of the GetTupleElement + // This test is parameterized on whether the broadcast has an indirect use + // or not. The indirect use is controlled by the index of the GetTupleElement // instruction. If the element is 0, then the %negate operand aliases %bcast // (ie %bcast is used indirectly by %negate), otherwise the %negate operand // aliases %add_2. @@ -539,17 +537,17 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - // Pick a memory limit some where between 24KB (initial peak memory including - // parameter and output) and 20KB (peak memory possible with + // Pick a memory limit some where between 24KB (initial peak memory + // including parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, module.get())); - // Rematerialization should only occur if the rematerializable instruction has - // no indirect uses. + // Rematerialization should only occur if the rematerializable instruction + // has no indirect uses. if (indirectly_used) { - EXPECT_FALSE(changed); - EXPECT_EQ(entry_computation->instruction_count(), 8); + EXPECT_TRUE(changed); + EXPECT_EQ(entry_computation->instruction_count(), 3); } else { EXPECT_TRUE(changed); EXPECT_EQ(entry_computation->instruction_count(), 9); @@ -633,7 +631,7 @@ ENTRY %entry { %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0) %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float %reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float - %reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1) ROOT %add.2 = f32[] add(f32[] %add, f32[] %reduce.2) } @@ -847,6 +845,199 @@ ENTRY %mycomp (param: f32[1]) -> f32[1024] { EXPECT_TRUE(changed); } +TEST_F(HloRematerializationTest, RematTupleShape) { + const string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +%add_mul_comp { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + %x = f32[1024]{0} broadcast(f32[] %p0), dimensions={} + %y = f32[1024]{0} broadcast(f32[] %p1), dimensions={} + %add = f32[1024] add(%x, %y) + %mul = f32[1024] multiply(%x, %y) + ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul) +} + +ENTRY %entry { + %param.0 = f32[] parameter(0) + %param.1 = f32[] parameter(1) + %fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop, + calls=%add_mul_comp + %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0 + %add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1) + %broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={} + %mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1) + %gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1 + ROOT %add.2 = f32[1024]{0} add(f32[1024]{0} %mul, f32[1024]{0} %gte.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloComputation* computation = module->entry_computation(); + const HloInstruction* add = computation->root_instruction(); + ASSERT_THAT(add, op::Add(op::Multiply(), 
op::GetTupleElement(op::Fusion()))); + const HloInstruction* fusion = add->operand(0)->operand(0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/11 * 1024, module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT( + add, op::Add(op::Multiply(), op::GetTupleElement(AllOf( + op::Fusion(), ::testing::Ne(fusion))))); +} + +TEST_F(HloRematerializationTest, RematTupleShapeDoubleUse) { + const string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +%add_mul_comp { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + %x = f32[1024]{0} broadcast(f32[] %p0), dimensions={} + %y = f32[1024]{0} broadcast(f32[] %p1), dimensions={} + %add = f32[1024] add(%x, %y) + %mul = f32[1024] multiply(%x, %y) + ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul) +} + +ENTRY %entry { + %param.0 = f32[] parameter(0) + %param.1 = f32[] parameter(1) + %fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop, + calls=%add_mul_comp + %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0 + %add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1) + %broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={} + %mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1) + %gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1 + %gte.3 = f32[1024]{0} get-tuple-element(%fus), index=0 + %add.2 = f32[1024]{0} add(f32[1024]{0} %mul, f32[1024]{0} %gte.2) + ROOT %mul.2 = f32[1024]{0} multiply(f32[1024]{0} %add.2, f32[1024]{0} %gte.3) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloComputation* computation = module->entry_computation(); + const HloInstruction* add = computation->root_instruction(); + ASSERT_THAT(add, op::Multiply(op::Add(op::Multiply(), + op::GetTupleElement(op::Fusion())), + op::GetTupleElement(op::Fusion()))); + const HloInstruction* fusion = add->operand(0)->operand(0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/11 * 1024, module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT( + add, + op::Multiply( + op::Add(op::Multiply(), op::GetTupleElement(AllOf( + op::Fusion(), ::testing::Ne(fusion)))), + op::GetTupleElement(AllOf(op::Fusion(), ::testing::Ne(fusion))))); + // Check that the rematerialized fusion is the same for both ops. 
+ EXPECT_EQ(add->operand(0)->operand(1)->operand(0), + add->operand(1)->operand(0)); +} + +TEST_F(HloRematerializationTest, RematTupleShapeThroughBitcasts) { + const string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +%add_mul_comp { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + %x = f32[1024]{0} broadcast(f32[] %p0), dimensions={} + %y = f32[1024]{0} broadcast(f32[] %p1), dimensions={} + %add = f32[1024] add(%x, %y) + %mul = f32[1024] multiply(%x, %y) + ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul) +} + +ENTRY %entry { + %param.0 = f32[] parameter(0) + %param.1 = f32[] parameter(1) + %fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop, + calls=%add_mul_comp + %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0 + %add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1) + %broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={} + %mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1) + %gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1 + %bc.1 = f32[1024,1]{0,1} bitcast(%mul) + %bc.2 = f32[1024,1]{0,1} bitcast(%gte.2) + ROOT %add.2 = f32[1024,1]{0,1} add(f32[1024,1]{0,1} %bc.1, + f32[1024,1]{0,1} %bc.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloComputation* computation = module->entry_computation(); + const HloInstruction* add = computation->root_instruction(); + ASSERT_THAT(add, op::Add(op::Bitcast(op::Multiply()), + op::Bitcast(op::GetTupleElement(op::Fusion())))); + const HloInstruction* fusion = add->operand(0)->operand(0)->operand(0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/11 * 1024, module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(add, op::Add(op::Bitcast(op::Multiply()), + op::Bitcast(op::GetTupleElement( + AllOf(op::Fusion(), ::testing::Ne(fusion)))))); +} + +TEST_F(HloRematerializationTest, RematThroughTuple) { + const string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +%add_mul_comp { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + %x = f32[1024]{0} broadcast(f32[] %p0), dimensions={} + %y = f32[1024]{0} broadcast(f32[] %p1), dimensions={} + %add = f32[1024] add(%x, %y) + %mul = f32[1024] multiply(%x, %y) + ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul) +} + +ENTRY %entry { + %param.0 = f32[] parameter(0) + %param.1 = f32[] parameter(1) + %fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop, + calls=%add_mul_comp + %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0 + %gte.3 = f32[1024]{0} get-tuple-element(%fus), index=1 + %add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.3) + %broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={} + %mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1) + %tpl = (f32[1024]{0}, f32[1024]{0}) tuple(%gte.1, %add) + %bc.1 = f32[1024,1]{0,1} bitcast(%mul) + %gte.2 = f32[1024]{0} get-tuple-element(%tpl), index=0 + ROOT %add.2 = f32[1024]{0} add(f32[1024]{0} %gte.2, f32[1024]{0} %add) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + const HloComputation* computation = module->entry_computation(); + const HloInstruction* add = computation->root_instruction(); + ASSERT_THAT(add, op::Add(op::GetTupleElement( + op::Tuple(op::GetTupleElement(op::Fusion()), _)), + op::Add())); + const HloInstruction* tuple = add->operand(0)->operand(0); + const HloInstruction* fusion = 
tuple->operand(0)->operand(0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/11 * 1024, module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT( + add, op::Add(op::GetTupleElement(AllOf(op::Fusion(), ::testing::Ne(tuple), + ::testing::Ne(fusion))), + op::Add())); +} } // namespace } // namespace xla
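
For reference, a minimal before/after sketch of the behavior the new RematTupleShape test locks down. The instruction names %fus.remat and %gte.2.remat are hypothetical placeholders; the test only asserts that the second operand of %add.2 becomes a get-tuple-element reading from a fusion distinct from the original %fus.

  // Before: element 1 of %fus must stay live from %fus all the way down to %add.2.
  %fus   = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop, calls=%add_mul_comp
  %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
  %add   = f32[1024]{0} add(%gte.1, %gte.1)
  %mul   = f32[1024]{0} multiply(%add, %broadcast.1)
  %gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
  %add.2 = f32[1024]{0} add(%mul, %gte.2)

  // After (sketch): the pass clones %fus near the late use and feeds %add.2 through a
  // fresh get-tuple-element of the clone, so the original %fus buffers can die after %add.
  // The original %gte.2 is left without users and is cleaned up by DCE.
  %fus        = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop, calls=%add_mul_comp
  %gte.1      = f32[1024]{0} get-tuple-element(%fus), index=0
  %add        = f32[1024]{0} add(%gte.1, %gte.1)
  %mul        = f32[1024]{0} multiply(%add, %broadcast.1)
  %fus.remat  = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop, calls=%add_mul_comp
  %gte.2.remat = f32[1024]{0} get-tuple-element(%fus.remat), index=1
  %add.2      = f32[1024]{0} add(%mul, %gte.2.remat)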