[XLA] Extend hlo_rematerialization pass to support rematerialization of tuple-producing instrs.

Allow rematerialization of tuple-producing instructions by extending the process we use to
rematerialize bitcasts so that it also handles get-tuple-element'ed buffers that are not nested.
This allows rematerializing through tuples as well.
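
As an illustration, here is a minimal sketch in HLO text of the intended effect, loosely following the RematTupleShape test added in this change (instruction names such as %fus.remat are illustrative, not the names the pass actually assigns; "..." stands for intervening instructions):

  Before (the tuple-producing %fus must stay live until its last use):
    %fus    = (f32[1024], f32[1024]) fusion(%param.0, %param.1), kind=kLoop, calls=%add_mul_comp
    %gte.1  = f32[1024] get-tuple-element(%fus), index=0
    ...
    %gte.2  = f32[1024] get-tuple-element(%fus), index=1
    ROOT %add.2 = f32[1024] add(%mul, %gte.2)

  After (the fusion is recomputed next to the late use, which now reads it through a fresh get-tuple-element):
    %fus        = (f32[1024], f32[1024]) fusion(%param.0, %param.1), kind=kLoop, calls=%add_mul_comp
    %gte.1      = f32[1024] get-tuple-element(%fus), index=0
    ...
    %fus.remat  = (f32[1024], f32[1024]) fusion(%param.0, %param.1), kind=kLoop, calls=%add_mul_comp
    %gte.remat  = f32[1024] get-tuple-element(%fus.remat), index=1
    ROOT %add.2 = f32[1024] add(%mul, %gte.remat)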

PiperOrigin-RevId: 352691189
Change-Id: Ia1a7674c7e32f1c53253cd5b674abce99f87d509
Marcello Maggioni 2021-01-19 17:41:31 -08:00 committed by TensorFlower Gardener
parent 5de40a1c5b
commit f7c7fbd40b
3 changed files with 378 additions and 120 deletions

@ -3757,6 +3757,7 @@ cc_library(
":call_graph",
":flatten_call_graph",
":hlo",
":hlo_casting_utils",
":hlo_dce",
":hlo_memory_scheduler",
":hlo_ordering",

@ -32,9 +32,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@ -98,6 +100,14 @@ bool CanBeRematerialized(
return rematerializable;
}
// Returns true if this instruction relays the buffers it uses to its own
// users and is one of the instructions we support rematerializing through.
bool IsSupportedIndirectUser(const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kBitcast ||
instruction->opcode() == HloOpcode::kGetTupleElement;
}
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
using BufferIdList = absl::InlinedVector<BufferId, 3>;
@ -162,10 +172,13 @@ struct Item {
struct ItemUse {
Item* user;
int64 operand_number;
absl::optional<int64> index;
ItemUse(Item* user, int64 op_num) : user(user), operand_number(op_num) {}
ItemUse(Item* user, int64 op_num, absl::optional<int64> index)
: user(user), operand_number(op_num), index(index) {}
bool operator==(const ItemUse& other) const {
return user == other.user && operand_number == other.operand_number;
return user == other.user && operand_number == other.operand_number &&
index == other.index;
}
};
@ -449,16 +462,22 @@ UsesList GetUsers(const InstructionList& instruction_list,
continue;
}
if (buffer_alias.instruction() != logical_buffer->instruction() &&
buffer_alias.instruction()->opcode() != HloOpcode::kBitcast) {
!IsSupportedIndirectUser(buffer_alias.instruction())) {
*has_indirect_users = true;
}
// A buffer may be used by the instruction via more than one alias. For
// example, a buffer which appears in more than one element of a tuple.
Item* user_item = instruction_list.GetItem(user);
absl::optional<int64> user_index =
logical_buffer->index().size() != 1
? absl::nullopt
: absl::make_optional(logical_buffer->index().back());
for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) {
if (!absl::c_linear_search(
users, ItemUse{user_item, static_cast<int>(op_idx)})) {
users.push_back(ItemUse{user_item, static_cast<int>(op_idx)});
users,
ItemUse{user_item, static_cast<int>(op_idx), user_index})) {
users.push_back(
ItemUse{user_item, static_cast<int>(op_idx), user_index});
}
}
}
@ -516,10 +535,6 @@ class MemoryUsageTracker {
// each call to BeginInstruction.
Status EndInstruction();
// Returns the number of bytes that the current memory usage will be reduced
// if the given instruction is rematerialized.
int64 MemoryReducedIfRematerialized(Item* item) const;
// Returns the number of bytes that the current memory usage will be reduced
// if the given instruction is compact.
int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;
@ -538,7 +553,7 @@ class MemoryUsageTracker {
// been transformed (rematerialization instruction created and connected
// to uses).
Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
absl::Span<Item*> bitcasts);
absl::Span<Item*> indirect_users);
// Selects and returns the best candidate instructions for rematerialization.
// A sequence of candidate instructions of length between min_block_size and
@ -612,6 +627,9 @@ class MemoryUsageTracker {
// buffer aliasing (eg, tuples).
bool has_indirect_uses;
// Position in the tuple this buffer definition lives in.
ShapeIndex index;
// The instructions which use this buffer.
UsesList users;
@ -639,8 +657,8 @@ class MemoryUsageTracker {
UsesList users = GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &has_indirect_uses);
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
logical_buffer->shape(), std::move(users), live_out,
has_indirect_uses);
logical_buffer->shape(), logical_buffer->index(),
std::move(users), live_out, has_indirect_uses);
}
// Create a new buffer representing a rematerialization of given buffer for
@ -654,7 +672,7 @@ class MemoryUsageTracker {
for (ItemUse& use : rematerialized_uses) {
CHECK(!use.user->placed) << use.user->instruction->name();
}
return NewBuffer(remat_item, original_buffer.shape,
return NewBuffer(remat_item, original_buffer.shape, original_buffer.index,
std::move(rematerialized_uses), /*live_out=*/false,
/*has_indirect_uses=*/false);
}
@ -715,7 +733,8 @@ class MemoryUsageTracker {
// Create a new buffer, add it to buffers_, and return a reference.
Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
UsesList&& uses, bool live_out, bool has_indirect_uses) {
const ShapeIndex& index, UsesList&& uses, bool live_out,
bool has_indirect_uses) {
int buffer_id = buffers_.size();
auto get_num_of_unique_users = [](const UsesList& uses) -> int64 {
absl::flat_hash_set<Item*> users_set;
@ -726,7 +745,7 @@ class MemoryUsageTracker {
};
buffers_.push_back(Buffer{
buffer_id, defining_instruction, size_function_(shape), shape, live_out,
has_indirect_uses, uses, get_num_of_unique_users(uses)});
has_indirect_uses, index, uses, get_num_of_unique_users(uses)});
return buffers_.back();
}
@ -931,51 +950,6 @@ int64 MemoryUsageTracker::MemoryReducedIfCompressed(
return memory_reduced;
}
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
CHECK_NE(in_progress_item_, nullptr);
if (!item->placed || item == in_progress_item_) {
return 0;
}
// TODO(b/37687140): Rematerialization can increase peak memory consumption at
// an earlier point in the program if rematerialization extends the live range
// of the operand of the instruction being rematerialized across the live
// range of the value of instruction being rematerialized. Don't rematerialize
// in this case (ie, return 0 here).
// Compute the amount of memory reduced (if any) by rematerializing
// 'instruction'. The LogicalBuffers defined by 'instruction' will no longer
// be live at this program point, so initially set memory_reduced to the
// size of its defined values.
int64 memory_reduced = 0;
for (BufferId buffer_id : item->buffers_defined) {
// Avoid rematerializing instructions with indirect uses as it is difficult
// to reason about liveness after rematerializing the instruction.
// TODO(b/37714814): Consider rematerializing instructions with indirect
// uses.
if (buffers_.at(buffer_id).has_indirect_uses) {
return 0;
}
if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
memory_reduced += AllocatedSize(buffer_id);
}
}
// Account for any logical buffers whose live range must be extended across
// this program point.
for (BufferId buffer_id : item->buffers_used) {
if (!IsCurrentlyLive(buffer_id)) {
// This logical buffer is used by 'instruction' but is not live at this
// program point. Rematerializing 'instruction' will extend the buffer's
// live range across this program point.
memory_reduced -= AllocatedSize(buffer_id);
}
}
return memory_reduced;
}
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
absl::Span<const Item* const> items) const {
CHECK_NE(in_progress_item_, nullptr);
@ -994,17 +968,21 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
// will no longer be live at this program point, so initially set
// memory_reduced to the size of its defined values.
for (BufferId buffer_id : item->buffers_defined) {
const Buffer& buffer = buffers_.at(buffer_id);
// Avoid rematerializing instructions with indirect uses as it is
// difficult to reason about liveness after rematerializing the
// instruction.
// Avoid rematerializing instructions with live out buffers.
// Avoid rematerializing buffers that are in nested tuples.
// TODO(mpurohit): Check why live_out buffers are an issue here.
if (buffers_.at(buffer_id).has_indirect_uses ||
buffers_.at(buffer_id).live_out) {
if (buffer.has_indirect_uses || buffer.live_out ||
buffer.index.size() > 1) {
return 0;
}
if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
if (IsInUse(buffer_id)) {
return 0;
}
if (IsCurrentlyLive(buffer_id)) {
memory_reduced += AllocatedSize(buffer_id);
}
}
@ -1053,10 +1031,15 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
}
original_buffer.users = std::move(placed_users);
original_buffer.unfinished_user_count = 0;
original_buffer.users.push_back(ItemUse{compressed_item, 0});
original_buffer.users.push_back(ItemUse{compressed_item, 0, absl::nullopt});
// The NewBuffer calls below may reallocate the vector containing the
// buffers, invalidating the original_buffer reference, so copy the index we
// need across those calls.
ShapeIndex copied_index = original_buffer.index;
Buffer& compressed_buffer =
NewBuffer(compressed_item, compressed_item->instruction->shape(),
{ItemUse{uncompressed_item, 0}}, /*live_out=*/false,
copied_index, {ItemUse{uncompressed_item, 0, absl::nullopt}},
/*live_out=*/false,
/*has_indirect_uses=*/false);
compressed_item->buffers_used = original_item->buffers_output;
compressed_item->buffers_output = {compressed_buffer.id};
@ -1064,7 +1047,7 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
Buffer& uncompressed_buffer =
NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
std::move(unplaced_users), /*live_out=*/false,
copied_index, std::move(unplaced_users), /*live_out=*/false,
/*has_indirect_uses=*/false);
uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
@ -1081,7 +1064,7 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
}
Status MemoryUsageTracker::AddRematerializedInstruction(
Item* original_item, Item* remat_item, absl::Span<Item*> bitcasts) {
Item* original_item, Item* remat_item, absl::Span<Item*> indirect_users) {
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
<< original_item->instruction->name()
<< ", remat_instruction = " << remat_item->instruction->name();
@ -1108,19 +1091,12 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
std::back_inserter(filtered_users),
[&](const ItemUse& iu) { return iu.user == original_item; });
for (ItemUse& u : filtered_users) {
buffer.users.push_back(ItemUse{remat_item, u.operand_number});
}
}
for (Item* bitcast : bitcasts) {
CHECK_EQ(bitcast->instruction->opcode(), HloOpcode::kBitcast);
for (BufferId buffer_id : bitcast->buffers_used) {
Buffer& buffer = buffers_.at(buffer_id);
buffer.unfinished_user_count++;
buffer.users.push_back(ItemUse{bitcast, 0});
buffer.users.push_back(ItemUse{remat_item, u.operand_number, u.index});
}
}
const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
indirect_users.end());
// Create a new set of Buffers defined by the new rematerialization
// instruction. Update the internal data structures and memory use to account
// for them.
@ -1133,7 +1109,19 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
if (user.user->placed) {
placed_users.push_back(user);
} else {
unplaced_users.push_back(user);
// Keep only the indirect users that are in the provided list. All the
// others are considered dead, so drop any buffer uses they perform and
// remove them from the buffer's user list.
if (!IsSupportedIndirectUser(user.user->instruction) ||
indirect_users_set.contains(user.user)) {
unplaced_users.push_back(user);
} else {
CHECK(user.user->buffers_defined.empty())
<< "Buffers defined expected to be empty for use passthrough "
"instructions";
user.user->buffers_output.clear();
user.user->buffers_used.clear();
}
}
}
old_buffer.users = std::move(placed_users);
@ -1146,10 +1134,68 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
remat_item->buffers_defined.push_back(new_buffer.id);
auto update_buffers = [old_buffer_id, new_buffer_id = new_buffer.id](
BufferIdList& to_update) {
std::replace(to_update.begin(), to_update.end(), old_buffer_id,
new_buffer_id);
};
// Update users with the id of the new buffer.
for (ItemUse& user : new_buffer.users) {
BufferIdList& buffers_used = user.user->buffers_used;
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
new_buffer.id);
update_buffers(user.user->buffers_used);
update_buffers(user.user->buffers_output);
}
}
// Update the indirect users with the id of the new buffers.
for (Item* indirect_user : indirect_users) {
// Source of the buffers that are going to be passed through.
const Item* source_item =
instruction_list_.GetItem(indirect_user->instruction->operand(0));
switch (indirect_user->instruction->opcode()) {
case HloOpcode::kBitcast: {
// If the source is another indirect user, copy its outputs into the used
// and output lists of the bitcast, because indirect users don't define any
// buffers.
if (IsSupportedIndirectUser(source_item->instruction)) {
indirect_user->buffers_used = source_item->buffers_output;
indirect_user->buffers_output = source_item->buffers_output;
} else {
// If it's a real instruction producing a buffer then copy the defined
// buffers into used and output.
indirect_user->buffers_used = source_item->buffers_defined;
indirect_user->buffers_output = source_item->buffers_defined;
}
break;
}
case HloOpcode::kGetTupleElement: {
// GTEs just use the tuple buffer and output the buffer they actually
// extract from the tuple.
const HloGetTupleElementInstruction* gte =
Cast<HloGetTupleElementInstruction>(indirect_user->instruction);
for (BufferId buffer_id : source_item->buffers_defined) {
const Buffer& def_buffer = buffers_.at(buffer_id);
if (def_buffer.index == ShapeIndex{gte->tuple_index()}) {
indirect_user->buffers_output.push_back(buffer_id);
}
// This is the tuple buffer.
if (def_buffer.index.empty()) {
indirect_user->buffers_used.push_back(buffer_id);
}
}
break;
}
default: {
LOG(FATAL) << "Unsupported indirect instruction with opcode "
<< HloOpcodeString(indirect_user->instruction->opcode());
break;
}
}
// Fix up buffer users for the indirect instructions. For GTEs this is only
// the tuple buffer, while for bitcasts it is the buffer they pass through.
for (BufferId buffer_id : indirect_user->buffers_used) {
Buffer& buffer = buffers_.at(buffer_id);
buffer.unfinished_user_count++;
buffer.users.push_back(ItemUse{indirect_user, 0, absl::nullopt});
}
}
@ -1414,6 +1460,10 @@ MemoryUsageTracker::PickRematerializationCandidates(
// break out of this loop. Move on to the next start_item.
break;
}
VLOG(5) << "Block contains:";
for (auto* hlo : block) {
VLOG(5) << hlo->instruction->name();
}
const int64 memory_reduced = MemoryReducedIfRematerialized(block);
if (memory_reduced > 0) {
@ -1509,21 +1559,33 @@ StatusOr<int64> RematerializeInstructions(
Item* remat_item = instruction_list->CreateItem(remat);
// Replace each remaining use of 'best' with the rematerialization.
absl::InlinedVector<Item*, 4> bitcasts;
absl::InlinedVector<Item*, 4> indirect_users;
absl::flat_hash_map<int64, HloInstruction*> gte_cache;
for (auto& user : memory_tracker->GetItemUses(best_item)) {
if (!memory_tracker->IsPlaced(user.user->instruction)) {
VLOG(2) << " Replacing use of " << best->name() << " in "
<< user.user->instruction->name() << " with " << remat->name();
const int64 op_idx = user.operand_number;
auto* remat_use = remat;
HloInstruction* remat_use = remat;
if (user.index) {
auto cached_gte = gte_cache.find(*user.index);
if (cached_gte == gte_cache.end()) {
remat_use = computation->AddInstruction(
HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(remat_use->shape(),
*user.index),
remat_use, *user.index));
indirect_users.push_back(instruction_list->CreateItem(remat_use));
gte_cache[*user.index] = remat_use;
} else {
remat_use = cached_gte->second;
}
}
if (user.user->instruction->operand(op_idx)->shape() !=
remat->shape()) {
remat_use = computation->AddInstruction(HloInstruction::CreateUnary(
user.user->instruction->operand(op_idx)->shape(),
HloOpcode::kBitcast, remat));
bitcasts.push_back(instruction_list->CreateItem(remat_use));
bitcasts.back()->buffers_output = remat_item->buffers_defined;
bitcasts.back()->buffers_used = remat_item->buffers_defined;
remat_use->shape()) {
remat_use = computation->AddInstruction(HloInstruction::CreateBitcast(
user.user->instruction->operand(op_idx)->shape(), remat_use));
indirect_users.push_back(instruction_list->CreateItem(remat_use));
}
TF_RETURN_IF_ERROR(
user.user->instruction->ReplaceOperandWith(op_idx, remat_use));
@ -1532,7 +1594,7 @@ StatusOr<int64> RematerializeInstructions(
// Account for the rematerialization in the memory tracker.
TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
best_item, remat_item, absl::MakeSpan(bitcasts)));
best_item, remat_item, absl::MakeSpan(indirect_users)));
// Insert rematerialized instruction right before the earliest unplaced
// use of the instruction *and* the earliest unplaced last use of any
@ -1540,14 +1602,18 @@ StatusOr<int64> RematerializeInstructions(
// because we don't want to extend the live range of remat's operands as
// this could increase memory usage.
ItemList place_before;
const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
indirect_users.end());
for (auto user : remat->users()) {
if (!absl::c_linear_search(bitcasts, instruction_list->GetItem(user))) {
if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
place_before.push_back(instruction_list->GetItem(user));
}
}
for (auto* bitcast : bitcasts) {
for (auto user : bitcast->instruction->users()) {
place_before.push_back(instruction_list->GetItem(user));
for (auto* indirect_user : indirect_users) {
for (auto user : indirect_user->instruction->users()) {
if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
place_before.push_back(instruction_list->GetItem(user));
}
}
}
for (auto* operand : remat->operands()) {
@ -1571,14 +1637,14 @@ StatusOr<int64> RematerializeInstructions(
}
instruction_list->InsertBeforeInstructions(remat_item, place_before);
for (auto* bitcast : bitcasts) {
for (auto* bitcast : indirect_users) {
instruction_list->InsertBeforeInstructions(bitcast, place_before);
}
// Helper function that looks through bitcasts when determining if there
// is an active user for an HloInstruction.
// Helper function that looks through indirect users when determining if
// there is an active user for an HloInstruction.
std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
for (auto* u : i->users()) {
if (u->opcode() != HloOpcode::kBitcast || !uses_empty(u)) {
if (!IsSupportedIndirectUser(u) || !uses_empty(u)) {
return false;
}
}
@ -1599,12 +1665,12 @@ StatusOr<int64> RematerializeInstructions(
instruction_list->Denylist(remat);
}
remat_move_instructions->insert(remat);
net_instructions_added += bitcasts.size();
net_instructions_added += indirect_users.size();
} else {
net_instructions_added += bitcasts.size() + 1;
net_instructions_added += indirect_users.size() + 1;
}
for (auto* bitcast : bitcasts) {
instruction_list->Denylist(bitcast->instruction);
for (auto* indirect_user : indirect_users) {
instruction_list->Denylist(indirect_user->instruction);
}
}
VLOG(1) << "Rematerializing instructions ["

@ -470,11 +470,10 @@ TEST_F(HloRematerializationTest, CopyNotRematerialized) {
class IndirectUseTest : public HloRematerializationTest,
public ::testing::WithParamInterface<bool> {};
TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
// Test that an rematerializable instruction is not rematerialized if it has
// an indirect use. Test is parameterized on whether the value has an indirect
// use, and the instruction should be rematerialized iff the value has no
// indirect use. Module:
TEST_P(IndirectUseTest, IndirectUseRematerialized) {
// Test that a rematerializable instruction is rematerialized even if it has
// an indirect use.
// Module:
//
// Entry computation:
// F32[] %param = {...}
@ -492,11 +491,10 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
// F32[1024] %slice = slice(%concat)
//
// The value %bcast is live across the call and rematerialization of %bcast
// across that point would reduce peak memory use by 4KB. However, %bcast is
// used indirectly in the %negate so rematerialization should not happen.
// across that point would reduce peak memory use by 4KB.
//
// This test is parameterized on whether the broadcast has an indirect use or
// not. The indirect use is controlled by the index of the GetTupleElement
// This test is parameterized on whether the broadcast has an indirect use
// or not. The indirect use is controlled by the index of the GetTupleElement
// instruction. If the element is 0, then the %negate operand aliases %bcast
// (ie %bcast is used indirectly by %negate), otherwise the %negate operand
// aliases %add_2.
@ -539,17 +537,17 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// Pick a memory limit some where between 24KB (initial peak memory
// including parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/22 * 1024, module.get()));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
// With rematerialization through tuples supported, rematerialization now
// occurs even when the instruction has indirect uses.
if (indirectly_used) {
EXPECT_FALSE(changed);
EXPECT_EQ(entry_computation->instruction_count(), 8);
EXPECT_TRUE(changed);
EXPECT_EQ(entry_computation->instruction_count(), 3);
} else {
EXPECT_TRUE(changed);
EXPECT_EQ(entry_computation->instruction_count(), 9);
@ -633,7 +631,7 @@ ENTRY %entry {
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
ROOT %add.2 = f32[] add(f32[] %add, f32[] %reduce.2)
}
@ -847,6 +845,199 @@ ENTRY %mycomp (param: f32[1]) -> f32[1024] {
EXPECT_TRUE(changed);
}
TEST_F(HloRematerializationTest, RematTupleShape) {
const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_mul_comp {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
%x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
%y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
%add = f32[1024] add(%x, %y)
%mul = f32[1024] multiply(%x, %y)
ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%param.1 = f32[] parameter(1)
%fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
calls=%add_mul_comp
%gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
%add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
%broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
%mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
%gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
ROOT %add.2 = f32[1024]{0} add(f32[1024]{0} %mul, f32[1024]{0} %gte.2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const HloComputation* computation = module->entry_computation();
const HloInstruction* add = computation->root_instruction();
ASSERT_THAT(add, op::Add(op::Multiply(), op::GetTupleElement(op::Fusion())));
const HloInstruction* fusion = add->operand(0)->operand(0);
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/11 * 1024, module.get()));
EXPECT_TRUE(changed);
ASSERT_THAT(
add, op::Add(op::Multiply(), op::GetTupleElement(AllOf(
op::Fusion(), ::testing::Ne(fusion)))));
}
TEST_F(HloRematerializationTest, RematTupleShapeDoubleUse) {
const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_mul_comp {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
%x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
%y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
%add = f32[1024] add(%x, %y)
%mul = f32[1024] multiply(%x, %y)
ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%param.1 = f32[] parameter(1)
%fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
calls=%add_mul_comp
%gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
%add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
%broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
%mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
%gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
%gte.3 = f32[1024]{0} get-tuple-element(%fus), index=0
%add.2 = f32[1024]{0} add(f32[1024]{0} %mul, f32[1024]{0} %gte.2)
ROOT %mul.2 = f32[1024]{0} multiply(f32[1024]{0} %add.2, f32[1024]{0} %gte.3)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const HloComputation* computation = module->entry_computation();
const HloInstruction* add = computation->root_instruction();
ASSERT_THAT(add, op::Multiply(op::Add(op::Multiply(),
op::GetTupleElement(op::Fusion())),
op::GetTupleElement(op::Fusion())));
const HloInstruction* fusion = add->operand(0)->operand(0);
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/11 * 1024, module.get()));
EXPECT_TRUE(changed);
ASSERT_THAT(
add,
op::Multiply(
op::Add(op::Multiply(), op::GetTupleElement(AllOf(
op::Fusion(), ::testing::Ne(fusion)))),
op::GetTupleElement(AllOf(op::Fusion(), ::testing::Ne(fusion)))));
// Check that the rematerialized fusion is the same for both ops.
EXPECT_EQ(add->operand(0)->operand(1)->operand(0),
add->operand(1)->operand(0));
}
TEST_F(HloRematerializationTest, RematTupleShapeThroughBitcasts) {
const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_mul_comp {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
%x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
%y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
%add = f32[1024] add(%x, %y)
%mul = f32[1024] multiply(%x, %y)
ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%param.1 = f32[] parameter(1)
%fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
calls=%add_mul_comp
%gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
%add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
%broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
%mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
%gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
%bc.1 = f32[1024,1]{0,1} bitcast(%mul)
%bc.2 = f32[1024,1]{0,1} bitcast(%gte.2)
ROOT %add.2 = f32[1024,1]{0,1} add(f32[1024,1]{0,1} %bc.1,
f32[1024,1]{0,1} %bc.2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const HloComputation* computation = module->entry_computation();
const HloInstruction* add = computation->root_instruction();
ASSERT_THAT(add, op::Add(op::Bitcast(op::Multiply()),
op::Bitcast(op::GetTupleElement(op::Fusion()))));
const HloInstruction* fusion = add->operand(0)->operand(0)->operand(0);
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/11 * 1024, module.get()));
EXPECT_TRUE(changed);
ASSERT_THAT(add, op::Add(op::Bitcast(op::Multiply()),
op::Bitcast(op::GetTupleElement(
AllOf(op::Fusion(), ::testing::Ne(fusion))))));
}
TEST_F(HloRematerializationTest, RematThroughTuple) {
const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_mul_comp {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
%x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
%y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
%add = f32[1024] add(%x, %y)
%mul = f32[1024] multiply(%x, %y)
ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%param.1 = f32[] parameter(1)
%fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
calls=%add_mul_comp
%gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
%gte.3 = f32[1024]{0} get-tuple-element(%fus), index=1
%add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.3)
%broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
%mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
%tpl = (f32[1024]{0}, f32[1024]{0}) tuple(%gte.1, %add)
%bc.1 = f32[1024,1]{0,1} bitcast(%mul)
%gte.2 = f32[1024]{0} get-tuple-element(%tpl), index=0
ROOT %add.2 = f32[1024]{0} add(f32[1024]{0} %gte.2, f32[1024]{0} %add)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const HloComputation* computation = module->entry_computation();
const HloInstruction* add = computation->root_instruction();
ASSERT_THAT(add, op::Add(op::GetTupleElement(
op::Tuple(op::GetTupleElement(op::Fusion()), _)),
op::Add()));
const HloInstruction* tuple = add->operand(0)->operand(0);
const HloInstruction* fusion = tuple->operand(0)->operand(0);
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/11 * 1024, module.get()));
EXPECT_TRUE(changed);
ASSERT_THAT(
add, op::Add(op::GetTupleElement(AllOf(op::Fusion(), ::testing::Ne(tuple),
::testing::Ne(fusion))),
op::Add()));
}
} // namespace
} // namespace xla