Rollback of [XLA] Allow hlo_rematerialization pass to look through bitcasts when doing rematerialization.

Causes hours-long compilation times on OOMing code using JAX PRNG.

PiperOrigin-RevId: 345611840
Change-Id: I2968f4ffc733c27ae2c37e032954a7d9f93cd818

This commit is contained in:
parent fedf3f45fb
commit 2797009193
@@ -3776,6 +3776,7 @@ cc_library(
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
@@ -16,10 +16,12 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <string>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"

@@ -153,7 +155,22 @@ struct Item {
  int64 position;
};

// Data structure meant to record the user of the buffer defined from an Item.
// It records also the operand_number from where such use derives, so that
// indirect uses can be better identified (like for example a buffer used
// through a bitcast).
struct ItemUse {
  Item* user;
  int64 operand_number;

  ItemUse(Item* user, int64 op_num) : user(user), operand_number(op_num) {}
  bool operator==(const ItemUse& other) const {
    return user == other.user && operand_number == other.operand_number;
  }
};

using ItemList = absl::InlinedVector<Item*, 3>;
using UsesList = absl::InlinedVector<ItemUse, 3>;

// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
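Note (illustration, not part of the diff): the ItemUse/UsesList pair above lets a buffer record the same consumer more than once, keyed by operand number, which is how a use that arrives through a bitcast stays distinguishable from a direct use. A minimal, self-contained sketch of that behavior, with Item reduced to a bare struct and int64_t standing in for TensorFlow's int64:

    #include <cstdint>
    #include "absl/container/inlined_vector.h"

    struct Item {
      const char* name;
    };

    // Same shape as the ItemUse in the diff: a (user, operand index) pair.
    struct ItemUse {
      Item* user;
      int64_t operand_number;
      ItemUse(Item* user, int64_t op_num) : user(user), operand_number(op_num) {}
      bool operator==(const ItemUse& other) const {
        return user == other.user && operand_number == other.operand_number;
      }
    };

    using UsesList = absl::InlinedVector<ItemUse, 3>;

    int main() {
      Item add{"add"};
      UsesList uses;
      // One consumer reading the same logical buffer through two operands,
      // e.g. once directly and once via a bitcast: both uses are kept because
      // they differ in operand_number even though the user is identical.
      uses.push_back(ItemUse{&add, 0});
      uses.push_back(ItemUse{&add, 1});
      return uses[0] == uses[1] ? 1 : 0;  // returns 0: the uses are distinct
    }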
@@ -412,11 +429,11 @@ class InstructionList {
// has_indirect_users to whether any of the uses is indirect. A use is indirect
// if the instruction defining logical_buffer is not an operand of the use. This
// can happen via buffer aliasing (eg, tuples).
ItemList GetUsers(const InstructionList& instruction_list,
UsesList GetUsers(const InstructionList& instruction_list,
                  const LogicalBuffer* logical_buffer,
                  const TuplePointsToAnalysis& points_to_analysis,
                  bool* has_indirect_users) {
  ItemList users;
  UsesList users;
  // To identify uses iterate through all HloInstruction users of the
  // BufferAliases of the logical buffer.
  *has_indirect_users = false;

@@ -431,14 +448,18 @@ ItemList GetUsers(const InstructionList& instruction_list,
        // instruction (the GTE instruction only uses the pointer vector).
        continue;
      }
      if (buffer_alias.instruction() != logical_buffer->instruction()) {
      if (buffer_alias.instruction() != logical_buffer->instruction() &&
          buffer_alias.instruction()->opcode() != HloOpcode::kBitcast) {
        *has_indirect_users = true;
      }
      // A buffer may be used by the instruction via more than one alias. For
      // example, a buffer which appears in more than one element of a tuple.
      Item* user_item = instruction_list.GetItem(user);
      if (!absl::c_linear_search(users, user_item)) {
        users.push_back(user_item);
      for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) {
        if (!absl::c_linear_search(
                users, ItemUse{user_item, static_cast<int>(op_idx)})) {
          users.push_back(ItemUse{user_item, static_cast<int>(op_idx)});
        }
      }
    }
  }

@@ -516,7 +537,8 @@ class MemoryUsageTracker {
  // is remat_item. This method should be called after the HLO graph has
  // been transformed (rematerialization instruction created and connected
  // to uses).
  Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
  Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
                                      absl::Span<Item*> bitcasts);

  // Selects and returns the best candidate instructions for rematerialization.
  // A sequence of candidate instructions of length between min_block_size and

@@ -538,6 +560,9 @@ class MemoryUsageTracker {
  // Returns whether 'item' has any unplaced users.
  bool HasUnplacedUsers(Item* item) const;

  // Returns the list of uses for a specific 'item'.
  const UsesList GetItemUses(Item* item) const;

  // Returns whether 'item' is currently in progress.
  bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }

@@ -588,7 +613,7 @@ class MemoryUsageTracker {
    bool has_indirect_uses;

    // The instructions which use this buffer.
    ItemList users;
    UsesList users;

    // The number of users (HloInstructions) of this buffer which have not yet
    // been placed in the sequence.

@@ -611,7 +636,7 @@ class MemoryUsageTracker {
      const LogicalBuffer* logical_buffer,
      const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
    bool has_indirect_uses = false;
    ItemList users = GetUsers(instruction_list_, logical_buffer,
    UsesList users = GetUsers(instruction_list_, logical_buffer,
                              points_to_analysis, &has_indirect_uses);
    return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
                     logical_buffer->shape(), std::move(users), live_out,

@@ -621,13 +646,13 @@ class MemoryUsageTracker {
  // Create a new buffer representing a rematerialization of given buffer for
  // the given uses.
  Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
                              ItemList&& rematerialized_uses) {
                              UsesList&& rematerialized_uses) {
    CHECK(original_buffer.defining_instruction->placed)
        << original_buffer.defining_instruction->instruction->name();
    CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
    CHECK(!original_buffer.live_out) << original_buffer.ToString();
    for (Item* use : rematerialized_uses) {
      CHECK(!use->placed) << use->instruction->name();
    for (ItemUse& use : rematerialized_uses) {
      CHECK(!use.user->placed) << use.user->instruction->name();
    }
    return NewBuffer(remat_item, original_buffer.shape,
                     std::move(rematerialized_uses), /*live_out=*/false,

@@ -665,8 +690,6 @@ class MemoryUsageTracker {
    return absl::c_linear_search(in_progress_uses, buffer_id);
  }

  // Returns whether the given buffer is live at the current program
  // point.
  bool IsCurrentlyLive(BufferId buffer_id) const {
    const Buffer& buffer = buffers_[buffer_id];
    return (buffer.defining_instruction->placed &&

@@ -692,11 +715,18 @@ class MemoryUsageTracker {

  // Create a new buffer, add it to buffers_, and return a reference.
  Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
                    ItemList&& users, bool live_out, bool has_indirect_uses) {
                    UsesList&& uses, bool live_out, bool has_indirect_uses) {
    int buffer_id = buffers_.size();
    auto get_num_of_unique_users = [](const UsesList& uses) -> int64 {
      absl::flat_hash_set<Item*> users_set;
      for (const ItemUse& use : uses) {
        users_set.insert(use.user);
      }
      return users_set.size();
    };
    buffers_.push_back(Buffer{
        buffer_id, defining_instruction, size_function_(shape), shape, live_out,
        has_indirect_uses, users, static_cast<int64>(users.size())});
        has_indirect_uses, uses, get_num_of_unique_users(uses)});
    return buffers_.back();
  }

@@ -771,12 +801,15 @@ MemoryUsageTracker::MemoryUsageTracker(

      // Add users of while to Buffer users.
      bool unused;
      for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
                                      points_to_analysis, &unused)) {
        if (!absl::c_linear_search(buffer->users, user_item)) {
          buffer->users.push_back(user_item);
      for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer,
                                         points_to_analysis, &unused)) {
        auto existing_user_it = absl::c_find_if(
            buffer->users,
            [&](const ItemUse& use) { return user_item.user == use.user; });
        if (existing_user_it == buffer->users.end()) {
          buffer->unfinished_user_count++;
          user_item->buffers_used.push_back(buffer->id);
          user_item.user->buffers_used.push_back(buffer->id);
          buffer->users.push_back(user_item);
        }
      }
    } else {

@@ -784,8 +817,10 @@ MemoryUsageTracker::MemoryUsageTracker(
                          logical_buffer, points_to_analysis,
                          ContainsKey(live_out_set, logical_buffer));
      item->buffers_defined.push_back(buffer->id);
      for (Item* user : buffer->users) {
        user->buffers_used.push_back(buffer->id);
      for (ItemUse& user : buffer->users) {
        if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) {
          user.user->buffers_used.push_back(buffer->id);
        }
      }
    }

@@ -1003,14 +1038,14 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
  // Compressed buffer is now alive.
  memory_usage_ += size_function_(compressed_item->instruction->shape());

  ItemList placed_users;
  ItemList unplaced_users;
  UsesList placed_users;
  UsesList unplaced_users;
  CHECK_EQ(original_item->buffers_output.size(), 1);
  BufferId original_buffer_id = original_item->buffers_output[0];
  Buffer& original_buffer = buffers_.at(original_buffer_id);
  for (Item* user : original_buffer.users) {
    if (user->placed) {
      CHECK(IsFinished(user)) << user->instruction->name();
  for (ItemUse& user : original_buffer.users) {
    if (user.user->placed) {
      CHECK(IsFinished(user.user)) << user.user->instruction->name();
      placed_users.push_back(user);
    } else {
      unplaced_users.push_back(user);

@@ -1018,10 +1053,10 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
  }
  original_buffer.users = std::move(placed_users);
  original_buffer.unfinished_user_count = 0;
  original_buffer.users.push_back(compressed_item);
  original_buffer.users.push_back(ItemUse{compressed_item, 0});
  Buffer& compressed_buffer =
      NewBuffer(compressed_item, compressed_item->instruction->shape(),
                {uncompressed_item}, /*live_out=*/false,
                {ItemUse{uncompressed_item, 0}}, /*live_out=*/false,
                /*has_indirect_uses=*/false);
  compressed_item->buffers_used = original_item->buffers_output;
  compressed_item->buffers_output = {compressed_buffer.id};

@@ -1036,8 +1071,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
  uncompressed_item->buffers_output = {uncompressed_buffer.id};
  uncompressed_item->buffers_defined = {uncompressed_buffer.id};

  for (Item* user : uncompressed_buffer.users) {
    BufferIdList& buffers_used = user->buffers_used;
  for (ItemUse& user : uncompressed_buffer.users) {
    BufferIdList& buffers_used = user.user->buffers_used;
    std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
                 uncompressed_buffer.id);
  }

@@ -1045,8 +1080,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
  return Status::OK();
}

Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
                                                        Item* remat_item) {
Status MemoryUsageTracker::AddRematerializedInstruction(
    Item* original_item, Item* remat_item, absl::Span<Item*> bitcasts) {
  VLOG(3) << "AddRematerializedInstruction: original_instruction = "
          << original_item->instruction->name()
          << ", remat_instruction = " << remat_item->instruction->name();

@@ -1067,9 +1102,23 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
      // Buffer used by this instruction was dead, now is alive.
      memory_usage_ += AllocatedSize(buffer.id);
    }

    buffer.unfinished_user_count++;
    buffer.users.push_back(remat_item);
    absl::InlinedVector<ItemUse, 2> filtered_users;
    std::copy_if(buffer.users.begin(), buffer.users.end(),
                 std::back_inserter(filtered_users),
                 [&](const ItemUse& iu) { return iu.user == original_item; });
    for (ItemUse& u : filtered_users) {
      buffer.users.push_back(ItemUse{remat_item, u.operand_number});
    }
  }

  for (Item* bitcast : bitcasts) {
    CHECK_EQ(bitcast->instruction->opcode(), HloOpcode::kBitcast);
    for (BufferId buffer_id : bitcast->buffers_used) {
      Buffer& buffer = buffers_.at(buffer_id);
      buffer.unfinished_user_count++;
      buffer.users.push_back(ItemUse{bitcast, 0});
    }
  }

  // Create a new set of Buffers defined by the new rematerialization
@@ -1078,10 +1127,10 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
  for (BufferId old_buffer_id : original_item->buffers_defined) {
    Buffer& old_buffer = buffers_.at(old_buffer_id);

    ItemList placed_users;
    ItemList unplaced_users;
    for (Item* user : old_buffer.users) {
      if (user->placed) {
    UsesList placed_users;
    UsesList unplaced_users;
    for (ItemUse& user : old_buffer.users) {
      if (user.user->placed) {
        placed_users.push_back(user);
      } else {
        unplaced_users.push_back(user);

@@ -1097,8 +1146,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
        RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));

    remat_item->buffers_defined.push_back(new_buffer.id);
    for (Item* user : new_buffer.users) {
      BufferIdList& buffers_used = user->buffers_used;
    for (ItemUse& user : new_buffer.users) {
      BufferIdList& buffers_used = user.user->buffers_used;
      std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
                   new_buffer.id);
    }

@@ -1131,6 +1180,10 @@ string MemoryUsageTracker::ToString() const {
      absl::StrAppend(&output, " ", buffer.ToString(), live, ", ",
                      buffer.unfinished_user_count, " unfinished uses\n");
    }
    absl::StrAppend(&output, " Outputs:\n");
    for (BufferId buffer_id : item->buffers_output) {
      absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
    }
    absl::StrAppend(&output, " Uses:\n");
    for (BufferId buffer_id : item->buffers_used) {
      absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");

@@ -1190,12 +1243,14 @@ bool MemoryUsageTracker::Check() const {
  }
  for (const Buffer& buffer : buffers_) {
    int64 unfinished_uses = 0;
    for (Item* user : buffer.users) {
      const BufferIdList& used_buffers = user->buffers_used;
    absl::flat_hash_set<Item*> already_counted_user;
    for (const ItemUse& user : buffer.users) {
      const BufferIdList& used_buffers = user.user->buffers_used;
      CHECK(absl::c_linear_search(used_buffers, buffer.id))
          << "Instruction " << user->instruction->name()
          << "Instruction " << user.user->instruction->name()
          << " used buffers is missing " << buffer.ToString();
      if (!IsFinished(user)) {
      if (!IsFinished(user.user) &&
          already_counted_user.insert(user.user).second) {
        unfinished_uses++;
      }
    }

@@ -1397,8 +1452,8 @@ MemoryUsageTracker::PickRematerializationCandidates(
bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
  for (BufferId buffer_id : item->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    for (Item* user : buffer.users) {
      if (!user->placed) {
    for (const ItemUse& user : buffer.users) {
      if (!user.user->placed) {
        return true;
      }
    }

@@ -1406,6 +1461,17 @@ bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
  return false;
}

const UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
  UsesList combined_users;
  for (BufferId buffer_id : item->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    for (const ItemUse& user : buffer.users) {
      combined_users.push_back(user);
    }
  }
  return combined_users;
}

StatusOr<int64> RematerializeInstructions(
    MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
    absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,

@@ -1443,18 +1509,30 @@ StatusOr<int64> RematerializeInstructions(
    Item* remat_item = instruction_list->CreateItem(remat);

    // Replace each remaining use of 'best' with the rematerialization.
    std::vector<HloInstruction*> best_users_copy = best->users();
    for (HloInstruction* user : best_users_copy) {
      if (!memory_tracker->IsPlaced(user)) {
    absl::InlinedVector<Item*, 4> bitcasts;
    for (auto& user : memory_tracker->GetItemUses(best_item)) {
      if (!memory_tracker->IsPlaced(user.user->instruction)) {
        VLOG(2) << " Replacing use of " << best->name() << " in "
                << user->name() << " with " << remat->name();
        TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
                << user.user->instruction->name() << " with " << remat->name();
        const int64 op_idx = user.operand_number;
        auto* remat_use = remat;
        if (user.user->instruction->operand(op_idx)->shape() !=
            remat->shape()) {
          remat_use = computation->AddInstruction(HloInstruction::CreateUnary(
              user.user->instruction->operand(op_idx)->shape(),
              HloOpcode::kBitcast, remat));
          bitcasts.push_back(instruction_list->CreateItem(remat_use));
          bitcasts.back()->buffers_output = remat_item->buffers_defined;
          bitcasts.back()->buffers_used = remat_item->buffers_defined;
        }
        TF_RETURN_IF_ERROR(
            user.user->instruction->ReplaceOperandWith(op_idx, remat_use));
      }
    }

    // Account for the rematerialization in the memory tracker.
    TF_RETURN_IF_ERROR(
        memory_tracker->AddRematerializedInstruction(best_item, remat_item));
    TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
        best_item, remat_item, absl::MakeSpan(bitcasts)));

    // Insert rematerialized instruction right before the earliest unplaced
    // use of the instruction *and* the earliest unplaced last use of any

@@ -1463,7 +1541,14 @@ StatusOr<int64> RematerializeInstructions(
    // this could increase memory usage.
    ItemList place_before;
    for (auto user : remat->users()) {
      place_before.push_back(instruction_list->GetItem(user));
      if (!absl::c_linear_search(bitcasts, instruction_list->GetItem(user))) {
        place_before.push_back(instruction_list->GetItem(user));
      }
    }
    for (auto* bitcast : bitcasts) {
      for (auto user : bitcast->instruction->users()) {
        place_before.push_back(instruction_list->GetItem(user));
      }
    }
    for (auto* operand : remat->operands()) {
      for (auto* operand_user : operand->users()) {

@@ -1486,12 +1571,25 @@ StatusOr<int64> RematerializeInstructions(
    }
    instruction_list->InsertBeforeInstructions(remat_item, place_before);

    for (auto* bitcast : bitcasts) {
      instruction_list->InsertBeforeInstructions(bitcast, place_before);
    }
    // Helper function that looks through bitcasts when determining if there
    // is an active user for an HloInstruction.
    std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
      for (auto* u : i->users()) {
        if (u->opcode() != HloOpcode::kBitcast || !uses_empty(u)) {
          return false;
        }
      }
      return true;
    };
    // If the rematerialized instruction is dead then rematerialization is
    // essentially a move. Don't delete the instruction now because we don't
    // want duplicate HloInstruction* values during the course of the
    // transformation because we keep maps with HloInstruction* values as
    // keys.
    if (best->users().empty()) {
    if (uses_empty(best)) {
      VLOG(2) << best->name() << " is now dead";
      if (ContainsKey(*remat_move_instructions, best)) {
        // Previously, 'best' was a rematerialization which killed the

@@ -1501,8 +1599,12 @@ StatusOr<int64> RematerializeInstructions(
        instruction_list->Denylist(remat);
      }
      remat_move_instructions->insert(remat);
      net_instructions_added += bitcasts.size();
    } else {
      net_instructions_added++;
      net_instructions_added += bitcasts.size() + 1;
    }
    for (auto* bitcast : bitcasts) {
      instruction_list->Denylist(bitcast->instruction);
    }
  }
  VLOG(1) << "Rematerializing instructions ["
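Note (illustration, not part of the diff): the uses_empty helper above treats an instruction as dead when every transitive user is itself a bitcast with no live users of its own. A standalone sketch of that recursion, using a hypothetical Node type rather than the real HloInstruction API:

    #include <string>
    #include <vector>

    struct Node {
      std::string opcode;        // e.g. "bitcast", "add"
      std::vector<Node*> users;  // consumers of this node's value
    };

    // A node is effectively unused if all of its users are bitcasts that are
    // themselves effectively unused.
    bool UsesEmpty(const Node* n) {
      for (const Node* u : n->users) {
        if (u->opcode != "bitcast" || !UsesEmpty(u)) {
          return false;
        }
      }
      return true;
    }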
@@ -748,6 +748,105 @@ ENTRY %entry {
              op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
}

// Test rematerialization of values through bitcasts
// Its expected that the broadcast gets rematerialized
TEST_F(HloRematerializationTest, ThroughBitcastRemat) {
  const string& hlo_string = R"(
HloModule fusion, is_scheduled=true

ENTRY %mycomp (param: f32[1]) -> f32[1] {
  %param = f32[1]{0} parameter(0)
  %reshape = f32[] reshape(f32[1]{0} %param)
  %broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
  %bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
  %negate = f32[1024,1]{1,0} negate(f32[1024,1]{1,0} %broadcast)
  %concatenate = f32[2048,1]{1,0} concatenate(f32[1024,1]{1,0} %negate, f32[1024,1]{1,0} %negate), dimensions={0}
  %slice = f32[1,1]{1,0} slice(f32[2048,1]{1,0} %concatenate), slice={[0:1], [0:1]}
  %bitcast.1 = f32[1]{0} bitcast(f32[1,1]{1,0} %slice)
  %concatenate.1 = f32[1025]{0} concatenate(f32[1024]{0} %bitcast, f32[1]{0} %bitcast.1), dimensions={0}
  ROOT %slice.1 = f32[1]{0} slice(f32[1025]{0} %concatenate.1), slice={[0:1]}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  auto* computation = module->entry_computation();
  // Find and save the original broadcast instruction which should be
  // rematerialized.
  const HloInstruction* slice = computation->root_instruction();
  ASSERT_THAT(slice,
              op::Slice(op::Concatenate(op::Bitcast(op::Broadcast(_)), _)));
  const HloInstruction* concat = slice->operand(0);
  const HloInstruction* bcast = concat->operand(0)->operand(0);

  // Computation requires 16KB without rematerialization, but uses only 12KB
  // with rematerialization so pick a memory limit between these values (14KB).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/14 * 1024, module.get()));
  EXPECT_TRUE(changed);

  // Root should not have changed.
  EXPECT_EQ(computation->root_instruction(), slice);

  // The bitcast for the rematerialized broadcast
  const HloInstruction* remat_bitcast = concat->operand(0);
  // The broadcast should have been rematerialized.
  const HloInstruction* remat_broadcast = remat_bitcast->operand(0);

  EXPECT_THAT(remat_broadcast, op::Broadcast(::testing::Ne(bcast)));

  // The rematerialized broadcast should be immediately before its bitcast
  // and the bitcast before the concatenate in the sequence.
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 2],
            concat);
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 3],
            remat_bitcast);
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 4],
            remat_broadcast);
}

// Test that the "deny list for move remats" engages when we rematerialize
// through bitcasts.
TEST_F(HloRematerializationTest, ThroughBitcastRematInfiniteLoop) {
  const string& hlo_string = R"(
HloModule fusion, is_scheduled=true

ENTRY %mycomp (param: f32[1]) -> f32[1024] {
  %param = f32[1]{0} parameter(0)
  %reshape = f32[] reshape(f32[1]{0} %param)
  %broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
  %bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
  %broadcast2 = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
  %bitcast2 = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast2)
  ROOT %add = f32[1024]{0} add(f32[1024]{0} %bitcast, f32[1024]{0} %bitcast2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  auto* computation = module->entry_computation();
  // Find and save the original broadcasts instruction which should be
  // rematerialized.
  const HloInstruction* add = computation->root_instruction();
  // Run with a low rematerialization limit that cannot be satisfied to make
  // sure that we don't get stuck in a loop trying to lower it.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/1024, module.get()));
  ASSERT_THAT(add, op::Add(op::Bitcast(op::Broadcast(_)),
                           op::Bitcast(op::Broadcast(_))));
  EXPECT_TRUE(changed);
}

}  // namespace

}  // namespace xla
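Note (illustration, not part of the diff): the 16KB/12KB figures quoted in ThroughBitcastRemat can be sanity-checked by hand, assuming 4-byte f32 elements and that %bitcast aliases %broadcast's buffer. Without rematerialization, at the point where %concatenate is computed, %broadcast (4KB) must stay live for its later %bitcast use, alongside %negate (4KB) and the %concatenate output (8KB), for a peak of about 16KB. With rematerialization, %broadcast can be freed before %concatenate and recomputed just before the %bitcast use, so the peak drops to roughly %negate plus %concatenate, about 12KB. The 14KB limit in the test sits between those two, forcing the pass to act while leaving it room to succeed.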