[XLA] Extend hlo_rematerialization pass to support rematerialization of tuple-producing instructions.

Allow rematerialization of tuple-producing instructions by extending the mechanism used to rematerialize bitcasts so that it also handles get-tuple-element'ed buffers that are not nested. This makes it possible to rematerialize through tuples as well.

PiperOrigin-RevId: 352691189
Change-Id: Ia1a7674c7e32f1c53253cd5b674abce99f87d509
parent 5de40a1c5b
commit f7c7fbd40b
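An illustrative sketch (not part of the diff below; the instruction names are hypothetical, modeled on the RematTupleShape test added in this change): a fusion that produces a tuple and is consumed only through get-tuple-element instructions previously could not be rematerialized, because those uses count as indirect. With this change the pass may duplicate the producer and feed a late use through a fresh get-tuple-element, roughly turning

    %fus   = (f32[1024], f32[1024]) fusion(%param.0, %param.1), calls=%add_mul_comp
    %gte.1 = f32[1024] get-tuple-element(%fus), index=0
    ... many intervening instructions ...
    %gte.2 = f32[1024] get-tuple-element(%fus), index=1

into a schedule where a second, rematerialized fusion feeds %gte.2, so the original tuple does not have to stay live across the whole gap.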
tensorflow/compiler/xla/service/BUILD

@@ -3757,6 +3757,7 @@ cc_library(
         ":call_graph",
         ":flatten_call_graph",
         ":hlo",
+        ":hlo_casting_utils",
         ":hlo_dce",
         ":hlo_memory_scheduler",
         ":hlo_ordering",
tensorflow/compiler/xla/service/hlo_rematerialization.cc

@@ -32,9 +32,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/buffer_value.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"

@@ -98,6 +100,14 @@ bool CanBeRematerialized(
   return rematerializable;
 }
 
+// Return if this is an instruction that relays the buffers it uses to its own
+// users and if this is one of these instructions we support the
+// rematerialization of.
+bool IsSupportedIndirectUser(const HloInstruction* instruction) {
+  return instruction->opcode() == HloOpcode::kBitcast ||
+         instruction->opcode() == HloOpcode::kGetTupleElement;
+}
+
 // Type holding a unique identifier for each Buffer object.
 using BufferId = int64;
 using BufferIdList = absl::InlinedVector<BufferId, 3>;
@@ -162,10 +172,13 @@ struct Item {
 struct ItemUse {
   Item* user;
   int64 operand_number;
+  absl::optional<int64> index;
 
-  ItemUse(Item* user, int64 op_num) : user(user), operand_number(op_num) {}
+  ItemUse(Item* user, int64 op_num, absl::optional<int64> index)
+      : user(user), operand_number(op_num), index(index) {}
   bool operator==(const ItemUse& other) const {
-    return user == other.user && operand_number == other.operand_number;
+    return user == other.user && operand_number == other.operand_number &&
+           index == other.index;
   }
 };
 
@@ -449,16 +462,22 @@ UsesList GetUsers(const InstructionList& instruction_list,
        continue;
      }
      if (buffer_alias.instruction() != logical_buffer->instruction() &&
-          buffer_alias.instruction()->opcode() != HloOpcode::kBitcast) {
+          !IsSupportedIndirectUser(buffer_alias.instruction())) {
        *has_indirect_users = true;
      }
      // A buffer may be used by the instruction via more than one alias. For
      // example, a buffer which appears in more than one element of a tuple.
      Item* user_item = instruction_list.GetItem(user);
+      absl::optional<int64> user_index =
+          logical_buffer->index().size() != 1
+              ? absl::nullopt
+              : absl::make_optional(logical_buffer->index().back());
      for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) {
        if (!absl::c_linear_search(
-                users, ItemUse{user_item, static_cast<int>(op_idx)})) {
-          users.push_back(ItemUse{user_item, static_cast<int>(op_idx)});
+                users,
+                ItemUse{user_item, static_cast<int>(op_idx), user_index})) {
+          users.push_back(
+              ItemUse{user_item, static_cast<int>(op_idx), user_index});
        }
      }
    }
@@ -516,10 +535,6 @@ class MemoryUsageTracker {
   // each call to BeginInstruction.
   Status EndInstruction();
 
-  // Returns the number of bytes that the current memory usage will be reduced
-  // if the given instruction is rematerialized.
-  int64 MemoryReducedIfRematerialized(Item* item) const;
-
   // Returns the number of bytes that the current memory usage will be reduced
   // if the given instruction is compact.
   int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;

@@ -538,7 +553,7 @@ class MemoryUsageTracker {
   // been transformed (rematerialization instruction created and connected
   // to uses).
   Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
-                                      absl::Span<Item*> bitcasts);
+                                      absl::Span<Item*> indirect_users);
 
   // Selects and returns the best candidate instructions for rematerialization.
   // A sequence of candidate instructions of length between min_block_size and

@@ -612,6 +627,9 @@ class MemoryUsageTracker {
   // buffer aliasing (eg, tuples).
   bool has_indirect_uses;
 
+  // Position in the tuple this buffer definition lives in.
+  ShapeIndex index;
+
   // The instructions which use this buffer.
   UsesList users;
 
@@ -639,8 +657,8 @@ class MemoryUsageTracker {
     UsesList users = GetUsers(instruction_list_, logical_buffer,
                               points_to_analysis, &has_indirect_uses);
     return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
-                     logical_buffer->shape(), std::move(users), live_out,
-                     has_indirect_uses);
+                     logical_buffer->shape(), logical_buffer->index(),
+                     std::move(users), live_out, has_indirect_uses);
   }
 
   // Create a new buffer representing a rematerialization of given buffer for

@@ -654,7 +672,7 @@ class MemoryUsageTracker {
     for (ItemUse& use : rematerialized_uses) {
       CHECK(!use.user->placed) << use.user->instruction->name();
     }
-    return NewBuffer(remat_item, original_buffer.shape,
+    return NewBuffer(remat_item, original_buffer.shape, original_buffer.index,
                      std::move(rematerialized_uses), /*live_out=*/false,
                      /*has_indirect_uses=*/false);
   }

@@ -715,7 +733,8 @@ class MemoryUsageTracker {
 
   // Create a new buffer, add it to buffers_, and return a reference.
   Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
-                    UsesList&& uses, bool live_out, bool has_indirect_uses) {
+                    const ShapeIndex& index, UsesList&& uses, bool live_out,
+                    bool has_indirect_uses) {
     int buffer_id = buffers_.size();
     auto get_num_of_unique_users = [](const UsesList& uses) -> int64 {
       absl::flat_hash_set<Item*> users_set;

@@ -726,7 +745,7 @@ class MemoryUsageTracker {
     };
     buffers_.push_back(Buffer{
         buffer_id, defining_instruction, size_function_(shape), shape, live_out,
-        has_indirect_uses, uses, get_num_of_unique_users(uses)});
+        has_indirect_uses, index, uses, get_num_of_unique_users(uses)});
     return buffers_.back();
   }
 
@@ -931,51 +950,6 @@ int64 MemoryUsageTracker::MemoryReducedIfCompressed(
   return memory_reduced;
 }
 
-int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
-  CHECK_NE(in_progress_item_, nullptr);
-  if (!item->placed || item == in_progress_item_) {
-    return 0;
-  }
-
-  // TODO(b/37687140): Rematerialization can increase peak memory consumption at
-  // an earlier point in the program if rematerialization extends the live range
-  // of the operand of the instruction being rematerialized across the live
-  // range of the value of instruction being rematerialized. Don't rematerialize
-  // in this case (ie, return 0 here).
-
-  // Compute the amount of memory reduced (if any) by rematerializing
-  // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer
-  // be live at this program point, so initially set memory_reduced to the
-  // size of its defined values.
-  int64 memory_reduced = 0;
-  for (BufferId buffer_id : item->buffers_defined) {
-    // Avoid rematerializing instructions with indirect uses as it is difficult
-    // to reason about liveness after rematerializing the instruction.
-    // TODO(b/37714814): Consider rematerializing instructions with indirect
-    // uses.
-    if (buffers_.at(buffer_id).has_indirect_uses) {
-      return 0;
-    }
-
-    if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
-      memory_reduced += AllocatedSize(buffer_id);
-    }
-  }
-
-  // Account for any logical buffers whose live range must be extended across
-  // this program point.
-  for (BufferId buffer_id : item->buffers_used) {
-    if (!IsCurrentlyLive(buffer_id)) {
-      // This logical buffer is used by 'instruction' but is not live at this
-      // program point. Rematerializing 'instruction' will extend the buffer's
-      // live range across this program point.
-      memory_reduced -= AllocatedSize(buffer_id);
-    }
-  }
-
-  return memory_reduced;
-}
-
 int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
     absl::Span<const Item* const> items) const {
   CHECK_NE(in_progress_item_, nullptr);
@@ -994,17 +968,21 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
     // will no longer be live at this program point, so initially set
     // memory_reduced to the size of its defined values.
     for (BufferId buffer_id : item->buffers_defined) {
+      const Buffer& buffer = buffers_.at(buffer_id);
      // Avoid rematerializing instructions with indirect uses as it is
      // difficult to reason about liveness after rematerializing the
      // instruction.
      // Avoid rematerializing instructions with live out buffers.
+      // Avoid rematerializing buffers that are in nested tuples.
      // TODO(mpurohit): Check why live_out buffers are an issue here.
-      if (buffers_.at(buffer_id).has_indirect_uses ||
-          buffers_.at(buffer_id).live_out) {
+      if (buffer.has_indirect_uses || buffer.live_out ||
+          buffer.index.size() > 1) {
        return 0;
      }
 
-      if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
+      if (IsInUse(buffer_id)) {
+        return 0;
+      }
+      if (IsCurrentlyLive(buffer_id)) {
        memory_reduced += AllocatedSize(buffer_id);
      }
    }
@@ -1053,10 +1031,15 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
   }
   original_buffer.users = std::move(placed_users);
   original_buffer.unfinished_user_count = 0;
-  original_buffer.users.push_back(ItemUse{compressed_item, 0});
+  original_buffer.users.push_back(ItemUse{compressed_item, 0, absl::nullopt});
+  // We are reallocating the vector containing the buffers potentially,
+  // invalidating the original_buffer reference, so copy the index that we need
+  // across NewBuffer calls.
+  ShapeIndex copied_index = original_buffer.index;
   Buffer& compressed_buffer =
       NewBuffer(compressed_item, compressed_item->instruction->shape(),
-                {ItemUse{uncompressed_item, 0}}, /*live_out=*/false,
+                copied_index, {ItemUse{uncompressed_item, 0, absl::nullopt}},
+                /*live_out=*/false,
                 /*has_indirect_uses=*/false);
   compressed_item->buffers_used = original_item->buffers_output;
   compressed_item->buffers_output = {compressed_buffer.id};

@@ -1064,7 +1047,7 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
 
   Buffer& uncompressed_buffer =
       NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
-                std::move(unplaced_users), /*live_out=*/false,
+                copied_index, std::move(unplaced_users), /*live_out=*/false,
                 /*has_indirect_uses=*/false);
 
   uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};

@@ -1081,7 +1064,7 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
 }
 
 Status MemoryUsageTracker::AddRematerializedInstruction(
-    Item* original_item, Item* remat_item, absl::Span<Item*> bitcasts) {
+    Item* original_item, Item* remat_item, absl::Span<Item*> indirect_users) {
   VLOG(3) << "AddRematerializedInstruction: original_instruction = "
           << original_item->instruction->name()
           << ", remat_instruction = " << remat_item->instruction->name();
@@ -1108,19 +1091,12 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
                  std::back_inserter(filtered_users),
                  [&](const ItemUse& iu) { return iu.user == original_item; });
     for (ItemUse& u : filtered_users) {
-      buffer.users.push_back(ItemUse{remat_item, u.operand_number});
-    }
-  }
-
-  for (Item* bitcast : bitcasts) {
-    CHECK_EQ(bitcast->instruction->opcode(), HloOpcode::kBitcast);
-    for (BufferId buffer_id : bitcast->buffers_used) {
-      Buffer& buffer = buffers_.at(buffer_id);
-      buffer.unfinished_user_count++;
-      buffer.users.push_back(ItemUse{bitcast, 0});
+      buffer.users.push_back(ItemUse{remat_item, u.operand_number, u.index});
     }
   }
 
+  const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
+                                                      indirect_users.end());
   // Create a new set of Buffers defined by the new rematerialization
   // instruction. Update the internal data structures and memory use to account
   // for them.

@@ -1133,7 +1109,19 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
       if (user.user->placed) {
         placed_users.push_back(user);
       } else {
-        unplaced_users.push_back(user);
+        // We keep only the indirect users that are in the provided list.
+        // We consider all the other dead and remove any buffer use they might
+        // perform and remove it from the buffer user list.
+        if (!IsSupportedIndirectUser(user.user->instruction) ||
+            indirect_users_set.contains(user.user)) {
+          unplaced_users.push_back(user);
+        } else {
+          CHECK(user.user->buffers_defined.empty())
+              << "Buffers defined expected to be empty for use passthrough "
+                 "instructions";
+          user.user->buffers_output.clear();
+          user.user->buffers_used.clear();
+        }
       }
     }
     old_buffer.users = std::move(placed_users);
@@ -1146,10 +1134,68 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
         RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
 
     remat_item->buffers_defined.push_back(new_buffer.id);
+    auto update_buffers = [old_buffer_id, new_buffer_id = new_buffer.id](
+                              BufferIdList& to_update) {
+      std::replace(to_update.begin(), to_update.end(), old_buffer_id,
+                   new_buffer_id);
+    };
+    // Update users with the id of the new buffer.
     for (ItemUse& user : new_buffer.users) {
-      BufferIdList& buffers_used = user.user->buffers_used;
-      std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
-                   new_buffer.id);
+      update_buffers(user.user->buffers_used);
+      update_buffers(user.user->buffers_output);
    }
  }
 
+  // Update the indirect users with the id of the new buffers.
+  for (Item* indirect_user : indirect_users) {
+    // Source of the buffers that are gonna be passthrough.
+    const Item* source_item =
+        instruction_list_.GetItem(indirect_user->instruction->operand(0));
+    switch (indirect_user->instruction->opcode()) {
+      case HloOpcode::kBitcast: {
+        // If the source is another indirect user then copy the output
+        // in the used and output lists of the bitcast as they don't define any
+        // buffer.
+        if (IsSupportedIndirectUser(source_item->instruction)) {
+          indirect_user->buffers_used = source_item->buffers_output;
+          indirect_user->buffers_output = source_item->buffers_output;
+        } else {
+          // If it's a real instruction producing a buffer then copy the defined
+          // buffers into used and output.
+          indirect_user->buffers_used = source_item->buffers_defined;
+          indirect_user->buffers_output = source_item->buffers_defined;
+        }
+        break;
+      }
+      case HloOpcode::kGetTupleElement: {
+        // GTEs just use the tuple buffer and output the buffer they actually
+        // extract from the tuple.
+        const HloGetTupleElementInstruction* gte =
+            Cast<HloGetTupleElementInstruction>(indirect_user->instruction);
+        for (BufferId buffer_id : source_item->buffers_defined) {
+          const Buffer& def_buffer = buffers_.at(buffer_id);
+          if (def_buffer.index == ShapeIndex{gte->tuple_index()}) {
+            indirect_user->buffers_output.push_back(buffer_id);
+          }
+          // This is the tuple buffer.
+          if (def_buffer.index.empty()) {
+            indirect_user->buffers_used.push_back(buffer_id);
+          }
+        }
+        break;
+      }
+      default: {
+        LOG(FATAL) << "Unsupported indirect instruction with opcode "
+                   << HloOpcodeString(indirect_user->instruction->opcode());
+        break;
+      }
+    }
+    // Fixup buffer users for the indirect instructions. For GTEs is only the
+    // tuple buffer, while for bitcast is the buffer they pass through.
+    for (BufferId buffer_id : indirect_user->buffers_used) {
+      Buffer& buffer = buffers_.at(buffer_id);
+      buffer.unfinished_user_count++;
+      buffer.users.push_back(ItemUse{indirect_user, 0, absl::nullopt});
+    }
+  }
+
@@ -1414,6 +1460,10 @@ MemoryUsageTracker::PickRematerializationCandidates(
         // break out of this loop. Move on to the next start_item.
         break;
       }
+      VLOG(5) << "Block contains:";
+      for (auto* hlo : block) {
+        VLOG(5) << hlo->instruction->name();
+      }
       const int64 memory_reduced = MemoryReducedIfRematerialized(block);
 
       if (memory_reduced > 0) {
@@ -1509,21 +1559,33 @@ StatusOr<int64> RematerializeInstructions(
     Item* remat_item = instruction_list->CreateItem(remat);
 
     // Replace each remaining use of 'best' with the rematerialization.
-    absl::InlinedVector<Item*, 4> bitcasts;
+    absl::InlinedVector<Item*, 4> indirect_users;
+    absl::flat_hash_map<int64, HloInstruction*> gte_cache;
     for (auto& user : memory_tracker->GetItemUses(best_item)) {
       if (!memory_tracker->IsPlaced(user.user->instruction)) {
         VLOG(2) << " Replacing use of " << best->name() << " in "
                 << user.user->instruction->name() << " with " << remat->name();
         const int64 op_idx = user.operand_number;
-        auto* remat_use = remat;
+        HloInstruction* remat_use = remat;
+        if (user.index) {
+          auto cached_gte = gte_cache.find(*user.index);
+          if (cached_gte == gte_cache.end()) {
+            remat_use = computation->AddInstruction(
+                HloInstruction::CreateGetTupleElement(
+                    ShapeUtil::GetTupleElementShape(remat_use->shape(),
+                                                    *user.index),
+                    remat_use, *user.index));
+            indirect_users.push_back(instruction_list->CreateItem(remat_use));
+            gte_cache[*user.index] = remat_use;
+          } else {
+            remat_use = cached_gte->second;
+          }
+        }
         if (user.user->instruction->operand(op_idx)->shape() !=
-            remat->shape()) {
-          remat_use = computation->AddInstruction(HloInstruction::CreateUnary(
-              user.user->instruction->operand(op_idx)->shape(),
-              HloOpcode::kBitcast, remat));
-          bitcasts.push_back(instruction_list->CreateItem(remat_use));
-          bitcasts.back()->buffers_output = remat_item->buffers_defined;
-          bitcasts.back()->buffers_used = remat_item->buffers_defined;
+            remat_use->shape()) {
+          remat_use = computation->AddInstruction(HloInstruction::CreateBitcast(
+              user.user->instruction->operand(op_idx)->shape(), remat_use));
+          indirect_users.push_back(instruction_list->CreateItem(remat_use));
         }
         TF_RETURN_IF_ERROR(
             user.user->instruction->ReplaceOperandWith(op_idx, remat_use));

@@ -1532,7 +1594,7 @@ StatusOr<int64> RematerializeInstructions(
 
     // Account for the rematerialization in the memory tracker.
     TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
-        best_item, remat_item, absl::MakeSpan(bitcasts)));
+        best_item, remat_item, absl::MakeSpan(indirect_users)));
 
     // Insert rematerialized instruction right before the earliest unplaced
    // use of the instruction *and* the earliest unplaced last use of any
@@ -1540,14 +1602,18 @@ StatusOr<int64> RematerializeInstructions(
     // because we don't want to extend the live range of remat's operands as
     // this could increase memory usage.
     ItemList place_before;
+    const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
+                                                        indirect_users.end());
     for (auto user : remat->users()) {
-      if (!absl::c_linear_search(bitcasts, instruction_list->GetItem(user))) {
+      if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
         place_before.push_back(instruction_list->GetItem(user));
       }
     }
-    for (auto* bitcast : bitcasts) {
-      for (auto user : bitcast->instruction->users()) {
-        place_before.push_back(instruction_list->GetItem(user));
+    for (auto* indirect_user : indirect_users) {
+      for (auto user : indirect_user->instruction->users()) {
+        if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
+          place_before.push_back(instruction_list->GetItem(user));
+        }
       }
     }
     for (auto* operand : remat->operands()) {

@@ -1571,14 +1637,14 @@ StatusOr<int64> RematerializeInstructions(
     }
     instruction_list->InsertBeforeInstructions(remat_item, place_before);
 
-    for (auto* bitcast : bitcasts) {
+    for (auto* bitcast : indirect_users) {
       instruction_list->InsertBeforeInstructions(bitcast, place_before);
     }
-    // Helper function that looks through bitcasts when determining if there
-    // is an active user for an HloInstruction.
+    // Helper function that looks through indirect users when determining if
+    // there is an active user for an HloInstruction.
     std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
       for (auto* u : i->users()) {
-        if (u->opcode() != HloOpcode::kBitcast || !uses_empty(u)) {
+        if (!IsSupportedIndirectUser(u) || !uses_empty(u)) {
           return false;
         }
       }
@@ -1599,12 +1665,12 @@ StatusOr<int64> RematerializeInstructions(
         instruction_list->Denylist(remat);
       }
       remat_move_instructions->insert(remat);
-      net_instructions_added += bitcasts.size();
+      net_instructions_added += indirect_users.size();
     } else {
-      net_instructions_added += bitcasts.size() + 1;
+      net_instructions_added += indirect_users.size() + 1;
     }
-    for (auto* bitcast : bitcasts) {
-      instruction_list->Denylist(bitcast->instruction);
+    for (auto* indirect_user : indirect_users) {
+      instruction_list->Denylist(indirect_user->instruction);
     }
   }
   VLOG(1) << "Rematerializing instructions ["
tensorflow/compiler/xla/service/hlo_rematerialization_test.cc

@@ -470,11 +470,10 @@ TEST_F(HloRematerializationTest, CopyNotRematerialized) {
 class IndirectUseTest : public HloRematerializationTest,
                         public ::testing::WithParamInterface<bool> {};
 
-TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
-  // Test that an rematerializable instruction is not rematerialized if it has
-  // an indirect use. Test is parameterized on whether the value has an indirect
-  // use, and the instruction should be rematerialized iff the value has no
-  // indirect use. Module:
+TEST_P(IndirectUseTest, IndirectUseRematerialized) {
+  // Test that an rematerializable instruction is rematerialized if it has
+  // indirect use
+  // Module:
   //
   // Entry computation:
   //   F32[] %param = {...}

@@ -492,11 +491,10 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
   //   F32[1024] %slice = slice(%concat)
   //
   // The value %bcast is live across the call and rematerialization of %bcast
-  // across that point would reduce peak memory use by 4KB. However, %bcast is
-  // used indirectly in the %negate so rematerialization should not happen.
+  // across that point would reduce peak memory use by 4KB.
   //
-  // This test is parameterized on whether the broadcast has an indirect use or
-  // not. The indirect use is controlled by the index of the GetTupleElement
+  // This test is parameterized on whether the broadcast has an indirect use
+  // or not. The indirect use is controlled by the index of the GetTupleElement
   // instruction. If the element is 0, then the %negate operand aliases %bcast
   // (ie %bcast is used indirectly by %negate), otherwise the %negate operand
   // aliases %add_2.

@@ -539,17 +537,17 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
 
   EXPECT_EQ(entry_computation->instruction_count(), 8);
 
-  // Pick a memory limit some where between 24KB (initial peak memory including
-  // parameter and output) and 20KB (peak memory possible with
+  // Pick a memory limit some where between 24KB (initial peak memory
+  // including parameter and output) and 20KB (peak memory possible with
   // rematerialization).
   TF_ASSERT_OK_AND_ASSIGN(bool changed,
                           RunHloRematerialization(
                               /*memory_limit_bytes=*/22 * 1024, module.get()));
-  // Rematerialization should only occur if the rematerializable instruction has
-  // no indirect uses.
+  // Rematerialization should only occur if the rematerializable instruction
+  // has no indirect uses.
   if (indirectly_used) {
-    EXPECT_FALSE(changed);
-    EXPECT_EQ(entry_computation->instruction_count(), 8);
+    EXPECT_TRUE(changed);
+    EXPECT_EQ(entry_computation->instruction_count(), 3);
   } else {
     EXPECT_TRUE(changed);
     EXPECT_EQ(entry_computation->instruction_count(), 9);
@@ -633,7 +631,7 @@ ENTRY %entry {
   %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
   %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
   %reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
-  %reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
+  %reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
   %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
   ROOT %add.2 = f32[] add(f32[] %add, f32[] %reduce.2)
 }
@@ -847,6 +845,199 @@ ENTRY %mycomp (param: f32[1]) -> f32[1024] {
   EXPECT_TRUE(changed);
 }
 
+TEST_F(HloRematerializationTest, RematTupleShape) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+%add_mul_comp {
+  %p0 = f32[] parameter(0)
+  %p1 = f32[] parameter(1)
+  %x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
+  %y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
+  %add = f32[1024] add(%x, %y)
+  %mul = f32[1024] multiply(%x, %y)
+  ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
+}
+
+ENTRY %entry {
+  %param.0 = f32[] parameter(0)
+  %param.1 = f32[] parameter(1)
+  %fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
+    calls=%add_mul_comp
+  %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
+  %add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
+  %broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
+  %mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
+  %gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
+  ROOT %add.2 = f32[1024]{0} add(f32[1024]{0} %mul, f32[1024]{0} %gte.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  const HloComputation* computation = module->entry_computation();
+  const HloInstruction* add = computation->root_instruction();
+  ASSERT_THAT(add, op::Add(op::Multiply(), op::GetTupleElement(op::Fusion())));
+  const HloInstruction* fusion = add->operand(0)->operand(0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/11 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+  ASSERT_THAT(
+      add, op::Add(op::Multiply(), op::GetTupleElement(AllOf(
+                                       op::Fusion(), ::testing::Ne(fusion)))));
+}
+
+TEST_F(HloRematerializationTest, RematTupleShapeDoubleUse) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+%add_mul_comp {
+  %p0 = f32[] parameter(0)
+  %p1 = f32[] parameter(1)
+  %x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
+  %y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
+  %add = f32[1024] add(%x, %y)
+  %mul = f32[1024] multiply(%x, %y)
+  ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
+}
+
+ENTRY %entry {
+  %param.0 = f32[] parameter(0)
+  %param.1 = f32[] parameter(1)
+  %fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
+    calls=%add_mul_comp
+  %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
+  %add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
+  %broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
+  %mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
+  %gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
+  %gte.3 = f32[1024]{0} get-tuple-element(%fus), index=0
+  %add.2 = f32[1024]{0} add(f32[1024]{0} %mul, f32[1024]{0} %gte.2)
+  ROOT %mul.2 = f32[1024]{0} multiply(f32[1024]{0} %add.2, f32[1024]{0} %gte.3)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  const HloComputation* computation = module->entry_computation();
+  const HloInstruction* add = computation->root_instruction();
+  ASSERT_THAT(add, op::Multiply(op::Add(op::Multiply(),
+                                        op::GetTupleElement(op::Fusion())),
+                                op::GetTupleElement(op::Fusion())));
+  const HloInstruction* fusion = add->operand(0)->operand(0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/11 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+  ASSERT_THAT(
+      add,
+      op::Multiply(
+          op::Add(op::Multiply(), op::GetTupleElement(AllOf(
+                                      op::Fusion(), ::testing::Ne(fusion)))),
+          op::GetTupleElement(AllOf(op::Fusion(), ::testing::Ne(fusion)))));
+  // Check that the rematerialized fusion is the same for both ops.
+  EXPECT_EQ(add->operand(0)->operand(1)->operand(0),
+            add->operand(1)->operand(0));
+}
+
+TEST_F(HloRematerializationTest, RematTupleShapeThroughBitcasts) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+%add_mul_comp {
+  %p0 = f32[] parameter(0)
+  %p1 = f32[] parameter(1)
+  %x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
+  %y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
+  %add = f32[1024] add(%x, %y)
+  %mul = f32[1024] multiply(%x, %y)
+  ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
+}
+
+ENTRY %entry {
+  %param.0 = f32[] parameter(0)
+  %param.1 = f32[] parameter(1)
+  %fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
+    calls=%add_mul_comp
+  %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
+  %add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.1)
+  %broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
+  %mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
+  %gte.2 = f32[1024]{0} get-tuple-element(%fus), index=1
+  %bc.1 = f32[1024,1]{0,1} bitcast(%mul)
+  %bc.2 = f32[1024,1]{0,1} bitcast(%gte.2)
+  ROOT %add.2 = f32[1024,1]{0,1} add(f32[1024,1]{0,1} %bc.1,
+    f32[1024,1]{0,1} %bc.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  const HloComputation* computation = module->entry_computation();
+  const HloInstruction* add = computation->root_instruction();
+  ASSERT_THAT(add, op::Add(op::Bitcast(op::Multiply()),
+                           op::Bitcast(op::GetTupleElement(op::Fusion()))));
+  const HloInstruction* fusion = add->operand(0)->operand(0)->operand(0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/11 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+  ASSERT_THAT(add, op::Add(op::Bitcast(op::Multiply()),
+                           op::Bitcast(op::GetTupleElement(
+                               AllOf(op::Fusion(), ::testing::Ne(fusion))))));
+}
+
+TEST_F(HloRematerializationTest, RematThroughTuple) {
+  const string& hlo_string = R"(
+HloModule fusion, is_scheduled=true
+
+%add_mul_comp {
+  %p0 = f32[] parameter(0)
+  %p1 = f32[] parameter(1)
+  %x = f32[1024]{0} broadcast(f32[] %p0), dimensions={}
+  %y = f32[1024]{0} broadcast(f32[] %p1), dimensions={}
+  %add = f32[1024] add(%x, %y)
+  %mul = f32[1024] multiply(%x, %y)
+  ROOT %out = (f32[1024], f32[1024]) tuple(%add, %mul)
+}
+
+ENTRY %entry {
+  %param.0 = f32[] parameter(0)
+  %param.1 = f32[] parameter(1)
+  %fus = (f32[1024]{0}, f32[1024]{0}) fusion(%param.0, %param.1), kind=kLoop,
+    calls=%add_mul_comp
+  %gte.1 = f32[1024]{0} get-tuple-element(%fus), index=0
+  %gte.3 = f32[1024]{0} get-tuple-element(%fus), index=1
+  %add = f32[1024]{0} add(f32[1024]{0} %gte.1, f32[1024]{0} %gte.3)
+  %broadcast.1 = f32[1024]{0} broadcast(f32[] %param.0), dimensions={}
+  %mul = f32[1024]{0} multiply(f32[1024]{0} %add, f32[1024]{0} %broadcast.1)
+  %tpl = (f32[1024]{0}, f32[1024]{0}) tuple(%gte.1, %add)
+  %bc.1 = f32[1024,1]{0,1} bitcast(%mul)
+  %gte.2 = f32[1024]{0} get-tuple-element(%tpl), index=0
+  ROOT %add.2 = f32[1024]{0} add(f32[1024]{0} %gte.2, f32[1024]{0} %add)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  const HloComputation* computation = module->entry_computation();
+  const HloInstruction* add = computation->root_instruction();
+  ASSERT_THAT(add, op::Add(op::GetTupleElement(
+                               op::Tuple(op::GetTupleElement(op::Fusion()), _)),
+                           op::Add()));
+  const HloInstruction* tuple = add->operand(0)->operand(0);
+  const HloInstruction* fusion = tuple->operand(0)->operand(0);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/11 * 1024, module.get()));
+  EXPECT_TRUE(changed);
+  ASSERT_THAT(
+      add, op::Add(op::GetTupleElement(AllOf(op::Fusion(), ::testing::Ne(tuple),
+                                             ::testing::Ne(fusion))),
+                   op::Add()));
+}
+
 }  // namespace
 
 }  // namespace xla