Rollback of [XLA] Allow hlo_rematerialization pass to look through bitcasts when doing rematerialization.
Causes hours-long compilation times on OOMing code using JAX PRNG. PiperOrigin-RevId: 345611840 Change-Id: I2968f4ffc733c27ae2c37e032954a7d9f93cd818
This commit is contained in:
parent
fedf3f45fb
commit
2797009193
@ -3776,6 +3776,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
|
|||||||
@ -16,10 +16,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
|
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <iterator>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
@ -153,7 +155,22 @@ struct Item {
|
|||||||
int64 position;
|
int64 position;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Data structure meant to record the user of the buffer defined from an Item.
|
||||||
|
// It records also the operand_number from where such use derives, so that
|
||||||
|
// indirect uses can be better identified (like for example a buffer used
|
||||||
|
// through a bitcast).
|
||||||
|
struct ItemUse {
|
||||||
|
Item* user;
|
||||||
|
int64 operand_number;
|
||||||
|
|
||||||
|
ItemUse(Item* user, int64 op_num) : user(user), operand_number(op_num) {}
|
||||||
|
bool operator==(const ItemUse& other) const {
|
||||||
|
return user == other.user && operand_number == other.operand_number;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
using ItemList = absl::InlinedVector<Item*, 3>;
|
using ItemList = absl::InlinedVector<Item*, 3>;
|
||||||
|
using UsesList = absl::InlinedVector<ItemUse, 3>;
|
||||||
|
|
||||||
// Class which maintains an ordered list of instructions with fast insertion
|
// Class which maintains an ordered list of instructions with fast insertion
|
||||||
// before arbitrary elements.
|
// before arbitrary elements.
|
||||||
@ -412,11 +429,11 @@ class InstructionList {
|
|||||||
// has_indirect_users to whether any of the uses is indirect. A use is indirect
|
// has_indirect_users to whether any of the uses is indirect. A use is indirect
|
||||||
// if the instruction defining logical_buffer is not an operand of the use. This
|
// if the instruction defining logical_buffer is not an operand of the use. This
|
||||||
// can happen via buffer aliasing (eg, tuples).
|
// can happen via buffer aliasing (eg, tuples).
|
||||||
ItemList GetUsers(const InstructionList& instruction_list,
|
UsesList GetUsers(const InstructionList& instruction_list,
|
||||||
const LogicalBuffer* logical_buffer,
|
const LogicalBuffer* logical_buffer,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
bool* has_indirect_users) {
|
bool* has_indirect_users) {
|
||||||
ItemList users;
|
UsesList users;
|
||||||
// To identify uses iterate through all HloInstruction users of the
|
// To identify uses iterate through all HloInstruction users of the
|
||||||
// BufferAliases of the logical buffer.
|
// BufferAliases of the logical buffer.
|
||||||
*has_indirect_users = false;
|
*has_indirect_users = false;
|
||||||
@ -431,14 +448,18 @@ ItemList GetUsers(const InstructionList& instruction_list,
|
|||||||
// instruction (the GTE instruction only uses the pointer vector).
|
// instruction (the GTE instruction only uses the pointer vector).
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (buffer_alias.instruction() != logical_buffer->instruction()) {
|
if (buffer_alias.instruction() != logical_buffer->instruction() &&
|
||||||
|
buffer_alias.instruction()->opcode() != HloOpcode::kBitcast) {
|
||||||
*has_indirect_users = true;
|
*has_indirect_users = true;
|
||||||
}
|
}
|
||||||
// A buffer may be used by the instruction via more than one alias. For
|
// A buffer may be used by the instruction via more than one alias. For
|
||||||
// example, a buffer which appears in more than one element of a tuple.
|
// example, a buffer which appears in more than one element of a tuple.
|
||||||
Item* user_item = instruction_list.GetItem(user);
|
Item* user_item = instruction_list.GetItem(user);
|
||||||
if (!absl::c_linear_search(users, user_item)) {
|
for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) {
|
||||||
users.push_back(user_item);
|
if (!absl::c_linear_search(
|
||||||
|
users, ItemUse{user_item, static_cast<int>(op_idx)})) {
|
||||||
|
users.push_back(ItemUse{user_item, static_cast<int>(op_idx)});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -516,7 +537,8 @@ class MemoryUsageTracker {
|
|||||||
// is remat_item. This method should be called after the HLO graph has
|
// is remat_item. This method should be called after the HLO graph has
|
||||||
// been transformed (rematerialization instruction created and connected
|
// been transformed (rematerialization instruction created and connected
|
||||||
// to uses).
|
// to uses).
|
||||||
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
|
Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
|
||||||
|
absl::Span<Item*> bitcasts);
|
||||||
|
|
||||||
// Selects and returns the best candidate instructions for rematerialization.
|
// Selects and returns the best candidate instructions for rematerialization.
|
||||||
// A sequence of candidate instructions of length between min_block_size and
|
// A sequence of candidate instructions of length between min_block_size and
|
||||||
@ -538,6 +560,9 @@ class MemoryUsageTracker {
|
|||||||
// Returns whether 'item' has any unplaced users.
|
// Returns whether 'item' has any unplaced users.
|
||||||
bool HasUnplacedUsers(Item* item) const;
|
bool HasUnplacedUsers(Item* item) const;
|
||||||
|
|
||||||
|
// Returns the list of uses for a specific 'item'.
|
||||||
|
const UsesList GetItemUses(Item* item) const;
|
||||||
|
|
||||||
// Returns whether 'item' is currently in progress.
|
// Returns whether 'item' is currently in progress.
|
||||||
bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }
|
bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }
|
||||||
|
|
||||||
@ -588,7 +613,7 @@ class MemoryUsageTracker {
|
|||||||
bool has_indirect_uses;
|
bool has_indirect_uses;
|
||||||
|
|
||||||
// The instructions which use this buffer.
|
// The instructions which use this buffer.
|
||||||
ItemList users;
|
UsesList users;
|
||||||
|
|
||||||
// The number of users (HloInstructions) of this buffer which have not yet
|
// The number of users (HloInstructions) of this buffer which have not yet
|
||||||
// been placed in the sequence.
|
// been placed in the sequence.
|
||||||
@ -611,7 +636,7 @@ class MemoryUsageTracker {
|
|||||||
const LogicalBuffer* logical_buffer,
|
const LogicalBuffer* logical_buffer,
|
||||||
const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
|
const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
|
||||||
bool has_indirect_uses = false;
|
bool has_indirect_uses = false;
|
||||||
ItemList users = GetUsers(instruction_list_, logical_buffer,
|
UsesList users = GetUsers(instruction_list_, logical_buffer,
|
||||||
points_to_analysis, &has_indirect_uses);
|
points_to_analysis, &has_indirect_uses);
|
||||||
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
|
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
|
||||||
logical_buffer->shape(), std::move(users), live_out,
|
logical_buffer->shape(), std::move(users), live_out,
|
||||||
@ -621,13 +646,13 @@ class MemoryUsageTracker {
|
|||||||
// Create a new buffer representing a rematerialization of given buffer for
|
// Create a new buffer representing a rematerialization of given buffer for
|
||||||
// the given uses.
|
// the given uses.
|
||||||
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
|
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
|
||||||
ItemList&& rematerialized_uses) {
|
UsesList&& rematerialized_uses) {
|
||||||
CHECK(original_buffer.defining_instruction->placed)
|
CHECK(original_buffer.defining_instruction->placed)
|
||||||
<< original_buffer.defining_instruction->instruction->name();
|
<< original_buffer.defining_instruction->instruction->name();
|
||||||
CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
|
CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
|
||||||
CHECK(!original_buffer.live_out) << original_buffer.ToString();
|
CHECK(!original_buffer.live_out) << original_buffer.ToString();
|
||||||
for (Item* use : rematerialized_uses) {
|
for (ItemUse& use : rematerialized_uses) {
|
||||||
CHECK(!use->placed) << use->instruction->name();
|
CHECK(!use.user->placed) << use.user->instruction->name();
|
||||||
}
|
}
|
||||||
return NewBuffer(remat_item, original_buffer.shape,
|
return NewBuffer(remat_item, original_buffer.shape,
|
||||||
std::move(rematerialized_uses), /*live_out=*/false,
|
std::move(rematerialized_uses), /*live_out=*/false,
|
||||||
@ -665,8 +690,6 @@ class MemoryUsageTracker {
|
|||||||
return absl::c_linear_search(in_progress_uses, buffer_id);
|
return absl::c_linear_search(in_progress_uses, buffer_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether the given buffer is live at the current program
|
|
||||||
// point.
|
|
||||||
bool IsCurrentlyLive(BufferId buffer_id) const {
|
bool IsCurrentlyLive(BufferId buffer_id) const {
|
||||||
const Buffer& buffer = buffers_[buffer_id];
|
const Buffer& buffer = buffers_[buffer_id];
|
||||||
return (buffer.defining_instruction->placed &&
|
return (buffer.defining_instruction->placed &&
|
||||||
@ -692,11 +715,18 @@ class MemoryUsageTracker {
|
|||||||
|
|
||||||
// Create a new buffer, add it to buffers_, and return a reference.
|
// Create a new buffer, add it to buffers_, and return a reference.
|
||||||
Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
|
Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
|
||||||
ItemList&& users, bool live_out, bool has_indirect_uses) {
|
UsesList&& uses, bool live_out, bool has_indirect_uses) {
|
||||||
int buffer_id = buffers_.size();
|
int buffer_id = buffers_.size();
|
||||||
|
auto get_num_of_unique_users = [](const UsesList& uses) -> int64 {
|
||||||
|
absl::flat_hash_set<Item*> users_set;
|
||||||
|
for (const ItemUse& use : uses) {
|
||||||
|
users_set.insert(use.user);
|
||||||
|
}
|
||||||
|
return users_set.size();
|
||||||
|
};
|
||||||
buffers_.push_back(Buffer{
|
buffers_.push_back(Buffer{
|
||||||
buffer_id, defining_instruction, size_function_(shape), shape, live_out,
|
buffer_id, defining_instruction, size_function_(shape), shape, live_out,
|
||||||
has_indirect_uses, users, static_cast<int64>(users.size())});
|
has_indirect_uses, uses, get_num_of_unique_users(uses)});
|
||||||
return buffers_.back();
|
return buffers_.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -771,12 +801,15 @@ MemoryUsageTracker::MemoryUsageTracker(
|
|||||||
|
|
||||||
// Add users of while to Buffer users.
|
// Add users of while to Buffer users.
|
||||||
bool unused;
|
bool unused;
|
||||||
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
|
for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer,
|
||||||
points_to_analysis, &unused)) {
|
points_to_analysis, &unused)) {
|
||||||
if (!absl::c_linear_search(buffer->users, user_item)) {
|
auto existing_user_it = absl::c_find_if(
|
||||||
buffer->users.push_back(user_item);
|
buffer->users,
|
||||||
|
[&](const ItemUse& use) { return user_item.user == use.user; });
|
||||||
|
if (existing_user_it == buffer->users.end()) {
|
||||||
buffer->unfinished_user_count++;
|
buffer->unfinished_user_count++;
|
||||||
user_item->buffers_used.push_back(buffer->id);
|
user_item.user->buffers_used.push_back(buffer->id);
|
||||||
|
buffer->users.push_back(user_item);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -784,8 +817,10 @@ MemoryUsageTracker::MemoryUsageTracker(
|
|||||||
logical_buffer, points_to_analysis,
|
logical_buffer, points_to_analysis,
|
||||||
ContainsKey(live_out_set, logical_buffer));
|
ContainsKey(live_out_set, logical_buffer));
|
||||||
item->buffers_defined.push_back(buffer->id);
|
item->buffers_defined.push_back(buffer->id);
|
||||||
for (Item* user : buffer->users) {
|
for (ItemUse& user : buffer->users) {
|
||||||
user->buffers_used.push_back(buffer->id);
|
if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) {
|
||||||
|
user.user->buffers_used.push_back(buffer->id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1003,14 +1038,14 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
|
|||||||
// Compressed buffer is now alive.
|
// Compressed buffer is now alive.
|
||||||
memory_usage_ += size_function_(compressed_item->instruction->shape());
|
memory_usage_ += size_function_(compressed_item->instruction->shape());
|
||||||
|
|
||||||
ItemList placed_users;
|
UsesList placed_users;
|
||||||
ItemList unplaced_users;
|
UsesList unplaced_users;
|
||||||
CHECK_EQ(original_item->buffers_output.size(), 1);
|
CHECK_EQ(original_item->buffers_output.size(), 1);
|
||||||
BufferId original_buffer_id = original_item->buffers_output[0];
|
BufferId original_buffer_id = original_item->buffers_output[0];
|
||||||
Buffer& original_buffer = buffers_.at(original_buffer_id);
|
Buffer& original_buffer = buffers_.at(original_buffer_id);
|
||||||
for (Item* user : original_buffer.users) {
|
for (ItemUse& user : original_buffer.users) {
|
||||||
if (user->placed) {
|
if (user.user->placed) {
|
||||||
CHECK(IsFinished(user)) << user->instruction->name();
|
CHECK(IsFinished(user.user)) << user.user->instruction->name();
|
||||||
placed_users.push_back(user);
|
placed_users.push_back(user);
|
||||||
} else {
|
} else {
|
||||||
unplaced_users.push_back(user);
|
unplaced_users.push_back(user);
|
||||||
@ -1018,10 +1053,10 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
|
|||||||
}
|
}
|
||||||
original_buffer.users = std::move(placed_users);
|
original_buffer.users = std::move(placed_users);
|
||||||
original_buffer.unfinished_user_count = 0;
|
original_buffer.unfinished_user_count = 0;
|
||||||
original_buffer.users.push_back(compressed_item);
|
original_buffer.users.push_back(ItemUse{compressed_item, 0});
|
||||||
Buffer& compressed_buffer =
|
Buffer& compressed_buffer =
|
||||||
NewBuffer(compressed_item, compressed_item->instruction->shape(),
|
NewBuffer(compressed_item, compressed_item->instruction->shape(),
|
||||||
{uncompressed_item}, /*live_out=*/false,
|
{ItemUse{uncompressed_item, 0}}, /*live_out=*/false,
|
||||||
/*has_indirect_uses=*/false);
|
/*has_indirect_uses=*/false);
|
||||||
compressed_item->buffers_used = original_item->buffers_output;
|
compressed_item->buffers_used = original_item->buffers_output;
|
||||||
compressed_item->buffers_output = {compressed_buffer.id};
|
compressed_item->buffers_output = {compressed_buffer.id};
|
||||||
@ -1036,8 +1071,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
|
|||||||
uncompressed_item->buffers_output = {uncompressed_buffer.id};
|
uncompressed_item->buffers_output = {uncompressed_buffer.id};
|
||||||
uncompressed_item->buffers_defined = {uncompressed_buffer.id};
|
uncompressed_item->buffers_defined = {uncompressed_buffer.id};
|
||||||
|
|
||||||
for (Item* user : uncompressed_buffer.users) {
|
for (ItemUse& user : uncompressed_buffer.users) {
|
||||||
BufferIdList& buffers_used = user->buffers_used;
|
BufferIdList& buffers_used = user.user->buffers_used;
|
||||||
std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
|
std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
|
||||||
uncompressed_buffer.id);
|
uncompressed_buffer.id);
|
||||||
}
|
}
|
||||||
@ -1045,8 +1080,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
|
Status MemoryUsageTracker::AddRematerializedInstruction(
|
||||||
Item* remat_item) {
|
Item* original_item, Item* remat_item, absl::Span<Item*> bitcasts) {
|
||||||
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
|
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
|
||||||
<< original_item->instruction->name()
|
<< original_item->instruction->name()
|
||||||
<< ", remat_instruction = " << remat_item->instruction->name();
|
<< ", remat_instruction = " << remat_item->instruction->name();
|
||||||
@ -1067,9 +1102,23 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
|
|||||||
// Buffer used by this instruction was dead, now is alive.
|
// Buffer used by this instruction was dead, now is alive.
|
||||||
memory_usage_ += AllocatedSize(buffer.id);
|
memory_usage_ += AllocatedSize(buffer.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer.unfinished_user_count++;
|
buffer.unfinished_user_count++;
|
||||||
buffer.users.push_back(remat_item);
|
absl::InlinedVector<ItemUse, 2> filtered_users;
|
||||||
|
std::copy_if(buffer.users.begin(), buffer.users.end(),
|
||||||
|
std::back_inserter(filtered_users),
|
||||||
|
[&](const ItemUse& iu) { return iu.user == original_item; });
|
||||||
|
for (ItemUse& u : filtered_users) {
|
||||||
|
buffer.users.push_back(ItemUse{remat_item, u.operand_number});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Item* bitcast : bitcasts) {
|
||||||
|
CHECK_EQ(bitcast->instruction->opcode(), HloOpcode::kBitcast);
|
||||||
|
for (BufferId buffer_id : bitcast->buffers_used) {
|
||||||
|
Buffer& buffer = buffers_.at(buffer_id);
|
||||||
|
buffer.unfinished_user_count++;
|
||||||
|
buffer.users.push_back(ItemUse{bitcast, 0});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new set of Buffers defined by the new rematerialization
|
// Create a new set of Buffers defined by the new rematerialization
|
||||||
@ -1078,10 +1127,10 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
|
|||||||
for (BufferId old_buffer_id : original_item->buffers_defined) {
|
for (BufferId old_buffer_id : original_item->buffers_defined) {
|
||||||
Buffer& old_buffer = buffers_.at(old_buffer_id);
|
Buffer& old_buffer = buffers_.at(old_buffer_id);
|
||||||
|
|
||||||
ItemList placed_users;
|
UsesList placed_users;
|
||||||
ItemList unplaced_users;
|
UsesList unplaced_users;
|
||||||
for (Item* user : old_buffer.users) {
|
for (ItemUse& user : old_buffer.users) {
|
||||||
if (user->placed) {
|
if (user.user->placed) {
|
||||||
placed_users.push_back(user);
|
placed_users.push_back(user);
|
||||||
} else {
|
} else {
|
||||||
unplaced_users.push_back(user);
|
unplaced_users.push_back(user);
|
||||||
@ -1097,8 +1146,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
|
|||||||
RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
|
RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
|
||||||
|
|
||||||
remat_item->buffers_defined.push_back(new_buffer.id);
|
remat_item->buffers_defined.push_back(new_buffer.id);
|
||||||
for (Item* user : new_buffer.users) {
|
for (ItemUse& user : new_buffer.users) {
|
||||||
BufferIdList& buffers_used = user->buffers_used;
|
BufferIdList& buffers_used = user.user->buffers_used;
|
||||||
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
|
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
|
||||||
new_buffer.id);
|
new_buffer.id);
|
||||||
}
|
}
|
||||||
@ -1131,6 +1180,10 @@ string MemoryUsageTracker::ToString() const {
|
|||||||
absl::StrAppend(&output, " ", buffer.ToString(), live, ", ",
|
absl::StrAppend(&output, " ", buffer.ToString(), live, ", ",
|
||||||
buffer.unfinished_user_count, " unfinished uses\n");
|
buffer.unfinished_user_count, " unfinished uses\n");
|
||||||
}
|
}
|
||||||
|
absl::StrAppend(&output, " Outputs:\n");
|
||||||
|
for (BufferId buffer_id : item->buffers_output) {
|
||||||
|
absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
|
||||||
|
}
|
||||||
absl::StrAppend(&output, " Uses:\n");
|
absl::StrAppend(&output, " Uses:\n");
|
||||||
for (BufferId buffer_id : item->buffers_used) {
|
for (BufferId buffer_id : item->buffers_used) {
|
||||||
absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
|
absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
|
||||||
@ -1190,12 +1243,14 @@ bool MemoryUsageTracker::Check() const {
|
|||||||
}
|
}
|
||||||
for (const Buffer& buffer : buffers_) {
|
for (const Buffer& buffer : buffers_) {
|
||||||
int64 unfinished_uses = 0;
|
int64 unfinished_uses = 0;
|
||||||
for (Item* user : buffer.users) {
|
absl::flat_hash_set<Item*> already_counted_user;
|
||||||
const BufferIdList& used_buffers = user->buffers_used;
|
for (const ItemUse& user : buffer.users) {
|
||||||
|
const BufferIdList& used_buffers = user.user->buffers_used;
|
||||||
CHECK(absl::c_linear_search(used_buffers, buffer.id))
|
CHECK(absl::c_linear_search(used_buffers, buffer.id))
|
||||||
<< "Instruction " << user->instruction->name()
|
<< "Instruction " << user.user->instruction->name()
|
||||||
<< " used buffers is missing " << buffer.ToString();
|
<< " used buffers is missing " << buffer.ToString();
|
||||||
if (!IsFinished(user)) {
|
if (!IsFinished(user.user) &&
|
||||||
|
already_counted_user.insert(user.user).second) {
|
||||||
unfinished_uses++;
|
unfinished_uses++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1397,8 +1452,8 @@ MemoryUsageTracker::PickRematerializationCandidates(
|
|||||||
bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
|
bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
|
||||||
for (BufferId buffer_id : item->buffers_defined) {
|
for (BufferId buffer_id : item->buffers_defined) {
|
||||||
const Buffer& buffer = buffers_.at(buffer_id);
|
const Buffer& buffer = buffers_.at(buffer_id);
|
||||||
for (Item* user : buffer.users) {
|
for (const ItemUse& user : buffer.users) {
|
||||||
if (!user->placed) {
|
if (!user.user->placed) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1406,6 +1461,17 @@ bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
|
||||||
|
UsesList combined_users;
|
||||||
|
for (BufferId buffer_id : item->buffers_defined) {
|
||||||
|
const Buffer& buffer = buffers_.at(buffer_id);
|
||||||
|
for (const ItemUse& user : buffer.users) {
|
||||||
|
combined_users.push_back(user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return combined_users;
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<int64> RematerializeInstructions(
|
StatusOr<int64> RematerializeInstructions(
|
||||||
MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
|
MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
|
||||||
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
|
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
|
||||||
@ -1443,18 +1509,30 @@ StatusOr<int64> RematerializeInstructions(
|
|||||||
Item* remat_item = instruction_list->CreateItem(remat);
|
Item* remat_item = instruction_list->CreateItem(remat);
|
||||||
|
|
||||||
// Replace each remaining use of 'best' with the rematerialization.
|
// Replace each remaining use of 'best' with the rematerialization.
|
||||||
std::vector<HloInstruction*> best_users_copy = best->users();
|
absl::InlinedVector<Item*, 4> bitcasts;
|
||||||
for (HloInstruction* user : best_users_copy) {
|
for (auto& user : memory_tracker->GetItemUses(best_item)) {
|
||||||
if (!memory_tracker->IsPlaced(user)) {
|
if (!memory_tracker->IsPlaced(user.user->instruction)) {
|
||||||
VLOG(2) << " Replacing use of " << best->name() << " in "
|
VLOG(2) << " Replacing use of " << best->name() << " in "
|
||||||
<< user->name() << " with " << remat->name();
|
<< user.user->instruction->name() << " with " << remat->name();
|
||||||
TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
|
const int64 op_idx = user.operand_number;
|
||||||
|
auto* remat_use = remat;
|
||||||
|
if (user.user->instruction->operand(op_idx)->shape() !=
|
||||||
|
remat->shape()) {
|
||||||
|
remat_use = computation->AddInstruction(HloInstruction::CreateUnary(
|
||||||
|
user.user->instruction->operand(op_idx)->shape(),
|
||||||
|
HloOpcode::kBitcast, remat));
|
||||||
|
bitcasts.push_back(instruction_list->CreateItem(remat_use));
|
||||||
|
bitcasts.back()->buffers_output = remat_item->buffers_defined;
|
||||||
|
bitcasts.back()->buffers_used = remat_item->buffers_defined;
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
user.user->instruction->ReplaceOperandWith(op_idx, remat_use));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Account for the rematerialization in the memory tracker.
|
// Account for the rematerialization in the memory tracker.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
|
||||||
memory_tracker->AddRematerializedInstruction(best_item, remat_item));
|
best_item, remat_item, absl::MakeSpan(bitcasts)));
|
||||||
|
|
||||||
// Insert rematerialized instruction right before the earliest unplaced
|
// Insert rematerialized instruction right before the earliest unplaced
|
||||||
// use of the instruction *and* the earliest unplaced last use of any
|
// use of the instruction *and* the earliest unplaced last use of any
|
||||||
@ -1463,7 +1541,14 @@ StatusOr<int64> RematerializeInstructions(
|
|||||||
// this could increase memory usage.
|
// this could increase memory usage.
|
||||||
ItemList place_before;
|
ItemList place_before;
|
||||||
for (auto user : remat->users()) {
|
for (auto user : remat->users()) {
|
||||||
place_before.push_back(instruction_list->GetItem(user));
|
if (!absl::c_linear_search(bitcasts, instruction_list->GetItem(user))) {
|
||||||
|
place_before.push_back(instruction_list->GetItem(user));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto* bitcast : bitcasts) {
|
||||||
|
for (auto user : bitcast->instruction->users()) {
|
||||||
|
place_before.push_back(instruction_list->GetItem(user));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (auto* operand : remat->operands()) {
|
for (auto* operand : remat->operands()) {
|
||||||
for (auto* operand_user : operand->users()) {
|
for (auto* operand_user : operand->users()) {
|
||||||
@ -1486,12 +1571,25 @@ StatusOr<int64> RematerializeInstructions(
|
|||||||
}
|
}
|
||||||
instruction_list->InsertBeforeInstructions(remat_item, place_before);
|
instruction_list->InsertBeforeInstructions(remat_item, place_before);
|
||||||
|
|
||||||
|
for (auto* bitcast : bitcasts) {
|
||||||
|
instruction_list->InsertBeforeInstructions(bitcast, place_before);
|
||||||
|
}
|
||||||
|
// Helper function that looks through bitcasts when determining if there
|
||||||
|
// is an active user for an HloInstruction.
|
||||||
|
std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
|
||||||
|
for (auto* u : i->users()) {
|
||||||
|
if (u->opcode() != HloOpcode::kBitcast || !uses_empty(u)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
// If the rematerialized instruction is dead then rematerialization is
|
// If the rematerialized instruction is dead then rematerialization is
|
||||||
// essentially a move. Don't delete the instruction now because we don't
|
// essentially a move. Don't delete the instruction now because we don't
|
||||||
// want duplicate HloInstruction* values during the course of the
|
// want duplicate HloInstruction* values during the course of the
|
||||||
// transformation because we keep maps with HloInstruction* values as
|
// transformation because we keep maps with HloInstruction* values as
|
||||||
// keys.
|
// keys.
|
||||||
if (best->users().empty()) {
|
if (uses_empty(best)) {
|
||||||
VLOG(2) << best->name() << " is now dead";
|
VLOG(2) << best->name() << " is now dead";
|
||||||
if (ContainsKey(*remat_move_instructions, best)) {
|
if (ContainsKey(*remat_move_instructions, best)) {
|
||||||
// Previously, 'best' was a rematerialization which killed the
|
// Previously, 'best' was a rematerialization which killed the
|
||||||
@ -1501,8 +1599,12 @@ StatusOr<int64> RematerializeInstructions(
|
|||||||
instruction_list->Denylist(remat);
|
instruction_list->Denylist(remat);
|
||||||
}
|
}
|
||||||
remat_move_instructions->insert(remat);
|
remat_move_instructions->insert(remat);
|
||||||
|
net_instructions_added += bitcasts.size();
|
||||||
} else {
|
} else {
|
||||||
net_instructions_added++;
|
net_instructions_added += bitcasts.size() + 1;
|
||||||
|
}
|
||||||
|
for (auto* bitcast : bitcasts) {
|
||||||
|
instruction_list->Denylist(bitcast->instruction);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
VLOG(1) << "Rematerializing instructions ["
|
VLOG(1) << "Rematerializing instructions ["
|
||||||
|
|||||||
@ -748,6 +748,105 @@ ENTRY %entry {
|
|||||||
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
|
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test rematerialization of values through bitcasts
|
||||||
|
// Its expected that the broadcast gets rematerialized
|
||||||
|
TEST_F(HloRematerializationTest, ThroughBitcastRemat) {
|
||||||
|
const string& hlo_string = R"(
|
||||||
|
HloModule fusion, is_scheduled=true
|
||||||
|
|
||||||
|
ENTRY %mycomp (param: f32[1]) -> f32[1] {
|
||||||
|
%param = f32[1]{0} parameter(0)
|
||||||
|
%reshape = f32[] reshape(f32[1]{0} %param)
|
||||||
|
%broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
|
||||||
|
%bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
|
||||||
|
%negate = f32[1024,1]{1,0} negate(f32[1024,1]{1,0} %broadcast)
|
||||||
|
%concatenate = f32[2048,1]{1,0} concatenate(f32[1024,1]{1,0} %negate, f32[1024,1]{1,0} %negate), dimensions={0}
|
||||||
|
%slice = f32[1,1]{1,0} slice(f32[2048,1]{1,0} %concatenate), slice={[0:1], [0:1]}
|
||||||
|
%bitcast.1 = f32[1]{0} bitcast(f32[1,1]{1,0} %slice)
|
||||||
|
%concatenate.1 = f32[1025]{0} concatenate(f32[1024]{0} %bitcast, f32[1]{0} %bitcast.1), dimensions={0}
|
||||||
|
ROOT %slice.1 = f32[1]{0} slice(f32[1025]{0} %concatenate.1), slice={[0:1]}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
|
auto* computation = module->entry_computation();
|
||||||
|
// Find and save the original broadcast instruction which should be
|
||||||
|
// rematerialized.
|
||||||
|
const HloInstruction* slice = computation->root_instruction();
|
||||||
|
ASSERT_THAT(slice,
|
||||||
|
op::Slice(op::Concatenate(op::Bitcast(op::Broadcast(_)), _)));
|
||||||
|
const HloInstruction* concat = slice->operand(0);
|
||||||
|
const HloInstruction* bcast = concat->operand(0)->operand(0);
|
||||||
|
|
||||||
|
// Computation requires 16KB without rematerialization, but uses only 12KB
|
||||||
|
// with rematerialization so pick a memory limit between these values (14KB).
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
|
RunHloRematerialization(
|
||||||
|
/*memory_limit_bytes=*/14 * 1024, module.get()));
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
|
||||||
|
// Root should not have changed.
|
||||||
|
EXPECT_EQ(computation->root_instruction(), slice);
|
||||||
|
|
||||||
|
// The bitcast for the rematerialized broadcast
|
||||||
|
const HloInstruction* remat_bitcast = concat->operand(0);
|
||||||
|
// The broadcast should have been rematerialized.
|
||||||
|
const HloInstruction* remat_broadcast = remat_bitcast->operand(0);
|
||||||
|
|
||||||
|
EXPECT_THAT(remat_broadcast, op::Broadcast(::testing::Ne(bcast)));
|
||||||
|
|
||||||
|
// The rematerialized broadcast should be immediately before its bitcast
|
||||||
|
// and the bitcast before the concatenate in the sequence.
|
||||||
|
EXPECT_EQ(module->schedule()
|
||||||
|
.sequence(computation)
|
||||||
|
.instructions()[computation->instruction_count() - 2],
|
||||||
|
concat);
|
||||||
|
EXPECT_EQ(module->schedule()
|
||||||
|
.sequence(computation)
|
||||||
|
.instructions()[computation->instruction_count() - 3],
|
||||||
|
remat_bitcast);
|
||||||
|
EXPECT_EQ(module->schedule()
|
||||||
|
.sequence(computation)
|
||||||
|
.instructions()[computation->instruction_count() - 4],
|
||||||
|
remat_broadcast);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the "deny list for move remats" engages when we rematerialize
|
||||||
|
// through bitcasts.
|
||||||
|
TEST_F(HloRematerializationTest, ThroughBitcastRematInfiniteLoop) {
|
||||||
|
const string& hlo_string = R"(
|
||||||
|
HloModule fusion, is_scheduled=true
|
||||||
|
|
||||||
|
ENTRY %mycomp (param: f32[1]) -> f32[1024] {
|
||||||
|
%param = f32[1]{0} parameter(0)
|
||||||
|
%reshape = f32[] reshape(f32[1]{0} %param)
|
||||||
|
%broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
|
||||||
|
%bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
|
||||||
|
%broadcast2 = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
|
||||||
|
%bitcast2 = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast2)
|
||||||
|
ROOT %add = f32[1024]{0} add(f32[1024]{0} %bitcast, f32[1024]{0} %bitcast2)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
|
auto* computation = module->entry_computation();
|
||||||
|
// Find and save the original broadcasts instruction which should be
|
||||||
|
// rematerialized.
|
||||||
|
const HloInstruction* add = computation->root_instruction();
|
||||||
|
// Run with a low rematerialization limit that cannot be satisfied to make
|
||||||
|
// sure that we don't get stuck in a loop trying to lower it.
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
|
RunHloRematerialization(
|
||||||
|
/*memory_limit_bytes=*/1024, module.get()));
|
||||||
|
ASSERT_THAT(add, op::Add(op::Bitcast(op::Broadcast(_)),
|
||||||
|
op::Bitcast(op::Broadcast(_))));
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user