Improve performance of compilation by ~8% by speeding up the
HLO rematerialization pass.

Changes:
* Wrap each HloInstruction* inside an Item structure that keeps
  associated data. This allows us to get rid of a bunch of
  hash tables indexed by HloInstruction*.
* Switch to an intrusive linked list (instead of std::list) so
  that we can avoid a hash table that maps to std::list::iterator.
* Use inlined vector in a few places.
PiperOrigin-RevId: 163848365
This commit is contained in:
A. Unique TensorFlower 2017-08-01 10:29:30 -07:00 committed by TensorFlower Gardener
parent 6d77a01293
commit 66f1485424

View File

@ -47,6 +47,12 @@ namespace xla {
namespace {
// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
// of candidates.
// . Cache IsRematerializable in Item? Only correct if control
// predecessors and successors don't change.
// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
// Conservatively, don't rematerialize instruction with control
@ -79,126 +85,202 @@ bool IsRematerializable(const HloInstruction* instruction) {
}
}
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
// Inlined capacity of 3 covers the common case (most instructions define/use
// only a few buffers) without heap allocation.
using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;
// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
// The underlying HLO instruction this Item decorates.
HloInstruction* instruction;
// True once the instruction is marked as placed (when BeginInstruction
// has been called for this instruction).
bool placed = false;
// To avoid an infinite loop rematerializing the same set of
// instructions ad infinitum, keep a blacklist of instructions
// which should not be rematerialized.
bool blacklisted = false;
// The buffers defined by this instruction.
BufferIdList buffers_defined;
// The buffers used by this instruction.
BufferIdList buffers_used;
private:
// Only InstructionList may touch the intrusive-list fields below.
friend class InstructionList;
// Items are arranged in a doubly linked list.
Item* next;
Item* prev;
// List is ordered by position, which can however be duplicated as
// new instructions are inserted. See InsertBeforeInstructions
// comment for details.
int64 position;
};
// Non-owning list of Items; ownership stays with InstructionList.
using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>;
// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
class InstructionList {
public:
explicit InstructionList(const std::vector<const HloInstruction*>& order) {
int64 position = 0;
Item* last = nullptr;
for (const HloInstruction* inst : order) {
instructions_.push_back(const_cast<HloInstruction*>(inst));
instruction_iterators_.insert({const_cast<HloInstruction*>(inst),
std::next(instructions_.end(), -1)});
// Add a new item to the linked list.
Item* item = new Item;
item->next = nullptr;
item->prev = last;
if (last == nullptr) {
first_ = item;
} else {
last->next = item;
}
last = item;
// Initially position numbers are uniquely assigned in order. Later as
// instructions are added with InsertBefore* methods, some instructions
// may have duplicate position numbers, but the values will be guaranteed
// to be monotonically increasing through the list, and so is still useful
// for quickly(-ish) determining the order of arbitrary instructions in
// the list.
position_number_[inst] = position;
first_at_position_[position] = inst;
item->instruction = const_cast<HloInstruction*>(inst);
item->position = position;
position++;
item_map_[inst] = item;
}
}
// Returns the list of instructions.
const std::list<HloInstruction*>& instructions() const {
return instructions_;
~InstructionList() {
for (Item* item = first_; item != nullptr;) {
Item* next = item->next;
delete item;
item = next;
}
}
// Insert instruction 'to_insert' immediately before instruction 'before' in
// the list.
void InsertBefore(HloInstruction* to_insert, HloInstruction* before) {
VLOG(3) << "InsertBefore: " << to_insert->name() << " before "
<< before->name();
auto it = instruction_iterators_.find(before);
CHECK(it != instruction_iterators_.end());
instruction_iterators_.insert(
{to_insert, instructions_.insert(it->second, to_insert)});
// Assign the same position number to the newly added instruction as
// 'before'. This guarantees monotonicity of the position numbers, but not
// uniqueness.
int64 pos = position_number_.at(before);
position_number_[to_insert] = pos;
if (first_at_position_.at(pos) == before) {
first_at_position_[pos] = to_insert;
}
size_t size() const { return item_map_.size(); }
// For ordered iteration over items.
// for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
Item* first() const { return first_; }
Item* next(Item* item) const { return item->next; }
// Creates an Item for the given instruction, but doesn't add it to the list.
// (Use InsertBeforeInstructions to add the Item to the list.)
Item* CreateItem(HloInstruction* inst) {
Item* item = new Item;
item->instruction = inst;
CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice";
return item;
}
// Return the Item corresponding to inst.
Item* GetItem(const HloInstruction* inst) const {
auto iter = item_map_.find(inst);
CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
return iter->second;
}
// Insert instruction 'to_insert' immediately before the earliest instruction
// in 'before_instructions'.
//
// Each instruction gets a non-decreasing ordinal number. We use this to let
// InsertBeforeInstructions quickly insert an instruction before the earliest
// instruction in a set of instructions. If position_number_[a] <
// position_number_[b] then 'a' comes before 'b' in the list. If the position
// numbers are the same then nothing can be said about their order without
// examining the list.
//
// On object construction this ordinal is precisely the instruction's index
// in the list. Later, instructions inserted via InsertBefore receive
// duplicate values. However, monotonicity is preserved.
void InsertBeforeInstructions(
HloInstruction* to_insert,
tensorflow::gtl::ArraySlice<HloInstruction*> before_instructions) {
VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {"
<< tensorflow::str_util::Join(
before_instructions, ", ",
[](string* out, HloInstruction* inst) {
tensorflow::strings::StrAppend(out, inst->name());
})
Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
<< " before {"
<< tensorflow::str_util::Join(before_instructions, ", ",
[](string* out, Item* item) {
tensorflow::strings::StrAppend(
out, item->instruction->name());
})
<< "}";
// Find the minimal position number of any instruction in
// 'before_instructions'.
CHECK(!before_instructions.empty());
int64 min_position_number = std::numeric_limits<int64>::max();
for (const HloInstruction* instruction : before_instructions) {
min_position_number =
std::min(min_position_number, position_number_.at(instruction));
Item* min_position_item = nullptr;
for (Item* item : before_instructions) {
if (min_position_item == nullptr ||
item->position < min_position_item->position) {
min_position_item = item;
}
}
// Because more than one instruction in 'before_instructions' may have a
// position number of 'min_position_number', find the first such instruction
// with position number 'min_position_number'.
for (auto it = instruction_iterators_.at(
first_at_position_.at(min_position_number));
it != instructions_.end() &&
position_number_.at(*it) == min_position_number;
++it) {
if (std::find(before_instructions.begin(), before_instructions.end(),
*it) != before_instructions.end()) {
return InsertBefore(to_insert, *it);
}
// First find first instruction with the min position.
while (min_position_item->prev != nullptr &&
min_position_item->position == min_position_item->prev->position) {
min_position_item = min_position_item->prev;
}
LOG(FATAL) << "Expected to find instruction in before_instructions with "
"position number "
<< min_position_number;
// Now scan forwards until we find one of the before_instructions.
while (std::find(before_instructions.begin(), before_instructions.end(),
min_position_item) == before_instructions.end()) {
min_position_item = min_position_item->next;
}
return InsertBefore(to_insert, min_position_item);
}
void Blacklist(const HloInstruction* inst) {
GetItem(inst)->blacklisted = true;
}
private:
// List of instructions.
std::list<HloInstruction*> instructions_;
// Insert instruction 'item' immediately before 'before' in the list.
void InsertBefore(Item* item, Item* before) {
VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
<< before->instruction->name();
// Insert new item into linked list.
item->prev = before->prev;
item->next = before;
before->prev = item;
if (item->prev != nullptr) {
item->prev->next = item;
} else {
first_ = item;
}
// Iterators for each instruction in the list.
tensorflow::gtl::FlatMap<const HloInstruction*,
std::list<HloInstruction*>::iterator>
instruction_iterators_;
// Assign the same position number to the newly added instruction as
// 'before'. This guarantees monotonicity of the position numbers, but not
// uniqueness.
item->position = before->position;
}
// A number assigned to each instruction which increases monotonically through
// 'instructions_'. Used to facilitate fast insertion of an instruction before
// the earliest instruction in a set of instructions
// (InsertBeforeInstructions) by enabling fast-ish ordering queries between
// instructions. If position_number_[a] < position_number_[b] then 'a' comes
// before 'b' in the list. If the position numbers are the same then nothing
// can be said about their order without examining the list.
//
// On object construction this value is precisely the instruction's ordinal
// position in the list. Instructions inserted via InsertBefore receive
// duplicate values. However, monotonicity is preserved.
tensorflow::gtl::FlatMap<const HloInstruction*, int64> position_number_;
Item* first_;
// The first instruction in the list assigned a particular position number.
tensorflow::gtl::FlatMap<int64, const HloInstruction*> first_at_position_;
// Item for each instruction.
tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_;
};
// Return the HloInstructions which use the given LogicalBuffer. Sets
// Return the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is indirect
// if the instruction defining logical_buffer is not an operand of the use. This
// can happen via buffer aliasing (eg, tuples).
std::vector<const HloInstruction*> GetUsers(
const LogicalBuffer* logical_buffer,
const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) {
std::vector<const HloInstruction*> users;
ItemList GetUsers(const InstructionList& instruction_list,
const LogicalBuffer* logical_buffer,
const TuplePointsToAnalysis& points_to_analysis,
bool* has_indirect_users) {
ItemList users;
// To identify uses iterate through all HloInstruction users of the
// BufferAliases of the logical buffer.
*has_indirect_users = false;
@ -219,8 +301,9 @@ std::vector<const HloInstruction*> GetUsers(
}
// A buffer may be used by the instruction via more than one alias. For
// example, a buffer which appears in more than one element of a tuple.
if (std::find(users.begin(), users.end(), user) == users.end()) {
users.push_back(user);
Item* user_item = instruction_list.GetItem(user);
if (std::find(users.begin(), users.end(), user_item) == users.end()) {
users.push_back(user_item);
}
}
}
@ -244,7 +327,7 @@ class MemoryUsageTracker {
// EndInstruction) to accurately model memory usage. At BeginInstruction the
// memory for the output value(s) of the current instruction is allocated. At
// EndInstruction memory for dead operand(s) is freed.
Status BeginInstruction(const HloInstruction* instruction);
Status BeginInstruction(Item* item);
// Finishes the placement of the current instruction. This frees any dead
// operands or dead result of the instruction. This must be called after
@ -253,40 +336,31 @@ class MemoryUsageTracker {
// Returns the number of bytes that the current memory usage will be reduced
// if the given instruction is rematerialized.
int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const;
int64 MemoryReducedIfRematerialized(Item* item) const;
// Adjusts memory usage to account for the rematerialization of
// original_instruction for all remaining unplaced uses. The rematerialization
// is remat_instruction. This method should be called after the HLO graph has
// original_item for all remaining unplaced uses. The rematerialization
// is remat_item. This method should be called after the HLO graph has
// been transformed (rematerialization instruction created and connected to
// uses).
Status AddRematerializedInstruction(HloInstruction* original_instruction,
HloInstruction* remat_instruction);
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
// Returns whether the given instruction has been placed (BeginInstruction
// has been called with 'instruction' as the argument).
bool IsPlaced(const HloInstruction* instruction) const {
return ContainsKey(placed_instructions_, instruction);
return instruction_list_.GetItem(instruction)->placed;
}
// Returns the current memory usage. This is the sum of sizes of all live
// values.
int64 memory_usage() const { return memory_usage_; }
// Returns the current instruction being placed.
const HloInstruction* in_progress_instruction() const {
return in_progress_instruction_;
}
// Check invariants of the data structure. This is expensive to call.
bool Check() const;
string ToString() const;
private:
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
// A Buffer represents a single LogicalBuffer in the computation including
// various metadata useful for tracking liveness of the value. A LogicalBuffer
// is not used directly because the HLO graph is transformed and
@ -298,7 +372,7 @@ class MemoryUsageTracker {
const BufferId id;
// The instruction which defines this buffer.
const HloInstruction* defining_instruction;
Item* defining_instruction;
// The materialized size of the buffer in bytes.
const int64 size;
@ -312,16 +386,17 @@ class MemoryUsageTracker {
bool has_indirect_uses;
// The instructions which use this buffer.
std::vector<const HloInstruction*> users;
ItemList users;
// The number of users (HloInstructions) of this buffer which have not yet
// been placed in the sequence.
int64 unfinished_user_count;
string ToString() const {
return tensorflow::strings::StrCat("Buffer ", id, " (defined by ",
defining_instruction->name(),
", size ", size, " bytes)");
return tensorflow::strings::StrCat(
"Buffer ", id, " (defined by ",
defining_instruction->instruction->name(), ", size ", size,
" bytes)");
}
};
@ -333,25 +408,24 @@ class MemoryUsageTracker {
const HloRematerialization::ShapeSizeFunction& size_function,
bool live_out) {
bool has_indirect_uses = false;
std::vector<const HloInstruction*> users =
GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses);
return NewBuffer(logical_buffer->instruction(),
ItemList users = GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &has_indirect_uses);
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
size_function(logical_buffer->shape()), std::move(users),
live_out, has_indirect_uses);
}
// Create a new buffer representing a rematerialization of given buffer for
// the given uses.
Buffer& RematerializeBuffer(
const Buffer& original_buffer, const HloInstruction* remat_instruction,
std::vector<const HloInstruction*>&& rematerialized_uses) {
CHECK(IsPlaced(original_buffer.defining_instruction));
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
ItemList&& rematerialized_uses) {
CHECK(original_buffer.defining_instruction->placed);
CHECK(!original_buffer.has_indirect_uses);
CHECK(!original_buffer.live_out);
for (const HloInstruction* use : rematerialized_uses) {
CHECK(!IsPlaced(use));
for (Item* use : rematerialized_uses) {
CHECK(!use->placed);
}
return NewBuffer(remat_instruction, original_buffer.size,
return NewBuffer(remat_item, original_buffer.size,
std::move(rematerialized_uses), /*live_out=*/false,
/*has_indirect_uses=*/false);
}
@ -362,7 +436,7 @@ class MemoryUsageTracker {
// different computation.
int64 AllocatedSize(BufferId buffer_id) const {
const Buffer& buffer = buffers_.at(buffer_id);
HloOpcode def_opcode = buffer.defining_instruction->opcode();
HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
return 0;
} else {
@ -372,18 +446,17 @@ class MemoryUsageTracker {
// Returns true if BeginInstruction and EndInstruction has been called for the
// given instruction.
bool IsFinished(const HloInstruction* instruction) const {
return IsPlaced(instruction) && instruction != in_progress_instruction_;
bool IsFinished(Item* item) const {
return item->placed && item != in_progress_item_;
}
// Returns whether the given buffer is being used by the in-progress
// instruction.
bool IsInUse(BufferId buffer_id) const {
if (in_progress_instruction_ == nullptr) {
if (in_progress_item_ == nullptr) {
return false;
}
const std::vector<BufferId>& in_progress_uses =
buffers_used_by_instruction_.at(in_progress_instruction_);
const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
return std::find(in_progress_uses.begin(), in_progress_uses.end(),
buffer_id) != in_progress_uses.end();
}
@ -392,14 +465,13 @@ class MemoryUsageTracker {
// point.
bool IsCurrentlyLive(BufferId buffer_id) const {
const Buffer& buffer = buffers_[buffer_id];
return (IsPlaced(buffer.defining_instruction) &&
return (buffer.defining_instruction->placed &&
buffer.unfinished_user_count > 0);
}
// Create a new buffer, add it to buffers_, and return a reference.
Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size,
std::vector<const HloInstruction*>&& users, bool live_out,
bool has_indirect_uses) {
Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
bool live_out, bool has_indirect_uses) {
int buffer_id = buffers_.size();
buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
has_indirect_uses, users,
@ -419,19 +491,7 @@ class MemoryUsageTracker {
// The instruction currently being placed. This value is non-null only
// between the calling of BeginInstruction and EndInstruction.
const HloInstruction* in_progress_instruction_ = nullptr;
// The buffers defined by each instruction.
std::unordered_map<const HloInstruction*, std::vector<BufferId>>
buffers_defined_by_instruction_;
// The buffers used by each instruction.
std::unordered_map<const HloInstruction*, std::vector<BufferId>>
buffers_used_by_instruction_;
// The set of instructions which have been placed. That is, BeginInstruction
// has been called with the instruction as an argument.
tensorflow::gtl::FlatSet<const HloInstruction*> placed_instructions_;
Item* in_progress_item_ = nullptr;
// All buffers in the computation.
std::vector<Buffer> buffers_;
@ -443,22 +503,15 @@ MemoryUsageTracker::MemoryUsageTracker(
const TuplePointsToAnalysis& points_to_analysis,
const InstructionList& instruction_list)
: computation_(computation), instruction_list_(instruction_list) {
// Iterate through all LogicalBuffers in the computation and gather the
// instructions which define them in buffers_defined_by_instruction_ and the
// instructions which use them in buffers_used_by_instruction_.
for (auto& instruction : computation_->instructions()) {
// Initialize empty vectors for defs and uses of each instruction.
buffers_used_by_instruction_[instruction.get()];
buffers_defined_by_instruction_[instruction.get()];
}
tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
logical_buffer_to_buffer_id;
for (const HloInstruction* instruction : instruction_list_.instructions()) {
for (auto* item = instruction_list_.first(); item != nullptr;
item = instruction_list_.next(item)) {
const HloInstruction* const instruction = item->instruction;
for (const LogicalBuffer* logical_buffer :
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
Buffer* buffer;
@ -481,22 +534,22 @@ MemoryUsageTracker::MemoryUsageTracker(
// Add users of while to Buffer users.
bool unused;
for (const HloInstruction* user :
GetUsers(logical_buffer, points_to_analysis, &unused)) {
if (std::find(buffer->users.begin(), buffer->users.end(), user) ==
buffer->users.end()) {
buffer->users.push_back(user);
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &unused)) {
if (std::find(buffer->users.begin(), buffer->users.end(),
user_item) == buffer->users.end()) {
buffer->users.push_back(user_item);
buffer->unfinished_user_count++;
buffers_used_by_instruction_.at(user).push_back(buffer->id);
user_item->buffers_used.push_back(buffer->id);
}
}
} else {
buffer = &CreateBufferFromLogicalBuffer(
logical_buffer, points_to_analysis, size_function,
ContainsKey(live_out_set, logical_buffer));
buffers_defined_by_instruction_.at(instruction).push_back(buffer->id);
for (const HloInstruction* user : buffer->users) {
buffers_used_by_instruction_.at(user).push_back(buffer->id);
item->buffers_defined.push_back(buffer->id);
for (Item* user : buffer->users) {
user->buffers_used.push_back(buffer->id);
}
}
@ -507,15 +560,16 @@ MemoryUsageTracker::MemoryUsageTracker(
DCHECK(Check());
}
Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) {
Status MemoryUsageTracker::BeginInstruction(Item* item) {
const HloInstruction* instruction = item->instruction;
VLOG(3) << "BeginInstruction " << instruction->name();
TF_RET_CHECK(in_progress_instruction_ == nullptr);
in_progress_instruction_ = instruction;
TF_RET_CHECK(in_progress_item_ == nullptr);
in_progress_item_ = item;
placed_instructions_.insert(in_progress_instruction_);
item->placed = true;
// All buffers defined by this instruction need memory.
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
for (BufferId buffer_id : item->buffers_defined) {
VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString()
<< " is now live.";
memory_usage_ += AllocatedSize(buffer_id);
@ -532,11 +586,10 @@ Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) {
}
Status MemoryUsageTracker::EndInstruction() {
TF_RET_CHECK(in_progress_instruction_ != nullptr);
VLOG(3) << "EndInstruction " << in_progress_instruction_->name();
TF_RET_CHECK(in_progress_item_ != nullptr);
VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();
for (BufferId buffer_id :
buffers_used_by_instruction_.at(in_progress_instruction_)) {
for (BufferId buffer_id : in_progress_item_->buffers_used) {
Buffer& buffer = buffers_.at(buffer_id);
buffer.unfinished_user_count--;
CHECK_GE(buffer.unfinished_user_count, 0)
@ -551,8 +604,7 @@ Status MemoryUsageTracker::EndInstruction() {
// If any buffer defined by this instruction has no uses, then memory can be
// reclaimed immediately.
for (BufferId buffer_id :
buffers_defined_by_instruction_.at(in_progress_instruction_)) {
for (BufferId buffer_id : in_progress_item_->buffers_defined) {
const Buffer& buffer = buffers_.at(buffer_id);
if (buffer.unfinished_user_count == 0) {
VLOG(3) << " " << buffer.ToString() << " is immediately dead.";
@ -561,7 +613,7 @@ Status MemoryUsageTracker::EndInstruction() {
}
}
in_progress_instruction_ = nullptr;
in_progress_item_ = nullptr;
VLOG(3) << " memory usage = " << memory_usage_;
VLOG(10) << ToString();
@ -571,10 +623,9 @@ Status MemoryUsageTracker::EndInstruction() {
return Status::OK();
}
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
const HloInstruction* instruction) const {
CHECK_NE(in_progress_instruction_, nullptr);
if (!IsPlaced(instruction) || instruction == in_progress_instruction_) {
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
CHECK_NE(in_progress_item_, nullptr);
if (!item->placed || item == in_progress_item_) {
return 0;
}
@ -589,7 +640,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
// be live at this program point, so initially set memory_reduced to the
// size of its defined values.
int64 memory_reduced = 0;
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
for (BufferId buffer_id : item->buffers_defined) {
// Avoid rematerializing instructions with indirect uses as it is difficult
// to reason about liveness after rematerializing the instruction.
// TODO(b/37714814): Consider rematerialzing instructions with indirect
@ -605,7 +656,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
// Account for any logical buffers whose live range must be extended across
// this program point.
for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) {
for (BufferId buffer_id : item->buffers_used) {
if (!IsCurrentlyLive(buffer_id)) {
// This logical buffer is used by 'instruction' but is not live at this
// program point. Rematerializing 'instruction' will extend the buffer's
@ -617,28 +668,23 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
return memory_reduced;
}
Status MemoryUsageTracker::AddRematerializedInstruction(
HloInstruction* original_instruction, HloInstruction* remat_instruction) {
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
Item* remat_item) {
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
<< original_instruction->name()
<< ", remat_instruction = " << remat_instruction->name();
<< original_item->instruction->name()
<< ", remat_instruction = " << remat_item->instruction->name();
TF_RET_CHECK(in_progress_instruction_ != nullptr);
TF_RET_CHECK(IsPlaced(original_instruction));
TF_RET_CHECK(!IsPlaced(remat_instruction));
CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction));
CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction));
TF_RET_CHECK(in_progress_item_ != nullptr);
TF_RET_CHECK(original_item->placed);
TF_RET_CHECK(!remat_item->placed);
// Construct the list of buffers used and defined by the rematerialization.
buffers_defined_by_instruction_[remat_instruction];
buffers_used_by_instruction_[remat_instruction] =
buffers_used_by_instruction_.at(original_instruction);
remat_item->buffers_used = original_item->buffers_used;
// Account for the additional buffer uses created by the new rematerialization
// instruction. Update memory usage if the rematerialization makes a dead
// buffer live again.
for (BufferId buffer_id :
buffers_used_by_instruction_.at(original_instruction)) {
for (BufferId buffer_id : original_item->buffers_used) {
Buffer& buffer = buffers_.at(buffer_id);
if (buffer.unfinished_user_count == 0) {
// Buffer used by this instruction was dead, now is alive.
@ -646,20 +692,19 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
}
buffer.unfinished_user_count++;
buffer.users.push_back(remat_instruction);
buffer.users.push_back(remat_item);
}
// Create a new set of Buffers defined by the new rematerialization
// instruction. Update the internal data structures and memory use to account
// for them.
for (BufferId old_buffer_id :
buffers_defined_by_instruction_.at(original_instruction)) {
for (BufferId old_buffer_id : original_item->buffers_defined) {
Buffer& old_buffer = buffers_.at(old_buffer_id);
std::vector<const HloInstruction*> placed_users;
std::vector<const HloInstruction*> unplaced_users;
for (const HloInstruction* user : old_buffer.users) {
if (IsPlaced(user)) {
ItemList placed_users;
ItemList unplaced_users;
for (Item* user : old_buffer.users) {
if (user->placed) {
CHECK(IsFinished(user));
placed_users.push_back(user);
} else {
@ -672,14 +717,12 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
// Buffer is now dead.
memory_usage_ -= AllocatedSize(old_buffer.id);
Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction,
std::move(unplaced_users));
Buffer& new_buffer =
RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
buffers_defined_by_instruction_.at(remat_instruction)
.push_back(new_buffer.id);
for (const HloInstruction* user : new_buffer.users) {
std::vector<BufferId>& buffers_used =
buffers_used_by_instruction_.at(user);
remat_item->buffers_defined.push_back(new_buffer.id);
for (Item* user : new_buffer.users) {
BufferIdList& buffers_used = user->buffers_used;
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
new_buffer.id);
}
@ -699,13 +742,14 @@ string MemoryUsageTracker::ToString() const {
tensorflow::strings::StrAppend(
&output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
memory_usage(), " bytes)");
for (const HloInstruction* instruction : instruction_list_.instructions()) {
string inprogress =
instruction == in_progress_instruction_ ? " in-progress" : "";
string placed = IsPlaced(instruction) ? " placed" : "";
for (auto* item = instruction_list_.first(); item != nullptr;
item = instruction_list_.next(item)) {
const HloInstruction* instruction = item->instruction;
string inprogress = item == in_progress_item_ ? " in-progress" : "";
string placed = item->placed ? " placed" : "";
tensorflow::strings::StrAppend(&output, " ", instruction->name(),
inprogress, placed, "\n Defines:\n");
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
for (BufferId buffer_id : item->buffers_defined) {
const Buffer& buffer = buffers_[buffer_id];
string live = IsCurrentlyLive(buffer_id) ? " live" : "";
tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live,
@ -713,7 +757,7 @@ string MemoryUsageTracker::ToString() const {
" unfinished uses\n");
}
tensorflow::strings::StrAppend(&output, " Uses:\n");
for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) {
for (BufferId buffer_id : item->buffers_used) {
tensorflow::strings::StrAppend(&output, " ",
buffers_[buffer_id].ToString(), "\n");
}
@ -722,14 +766,14 @@ string MemoryUsageTracker::ToString() const {
}
bool MemoryUsageTracker::Check() const {
auto elements_are_unique = [](const std::vector<BufferId>& vec) {
auto elements_are_unique = [](const BufferIdList& vec) {
return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
};
// Verify buffers_defined_by_instruction_.
// Verify buffers_defined per instruction.
for (auto& instruction : computation_->instructions()) {
const std::vector<BufferId>& defined_buffers =
buffers_defined_by_instruction_.at(instruction.get());
const BufferIdList& defined_buffers =
instruction_list_.GetItem(instruction.get())->buffers_defined;
CHECK(elements_are_unique(defined_buffers))
<< "Instruction " << instruction->name()
<< " does not have unique defined buffers: "
@ -740,7 +784,7 @@ bool MemoryUsageTracker::Check() const {
});
for (const Buffer& buffer : buffers_) {
if (buffer.defining_instruction == instruction.get()) {
if (buffer.defining_instruction->instruction == instruction.get()) {
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
buffer.id) != defined_buffers.end())
<< "Instruction " << instruction->name()
@ -749,10 +793,10 @@ bool MemoryUsageTracker::Check() const {
}
}
// Verify buffers_used_by_instruction_.
// Verify buffers_used per instruction.
for (auto& instruction : computation_->instructions()) {
const std::vector<BufferId>& used_buffers =
buffers_used_by_instruction_.at(instruction.get());
const BufferIdList& used_buffers =
instruction_list_.GetItem(instruction.get())->buffers_used;
CHECK(elements_are_unique(used_buffers))
<< "Instruction " << instruction->name()
<< " does not have unique used buffers: "
@ -764,13 +808,12 @@ bool MemoryUsageTracker::Check() const {
}
for (const Buffer& buffer : buffers_) {
int64 unfinished_uses = 0;
for (const HloInstruction* user : buffer.users) {
const std::vector<BufferId>& used_buffers =
buffers_used_by_instruction_.at(user);
for (Item* user : buffer.users) {
const BufferIdList& used_buffers = user->buffers_used;
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
used_buffers.end())
<< "Instruction " << user->name() << " used buffers is missing "
<< buffer.ToString();
<< "Instruction " << user->instruction->name()
<< " used buffers is missing " << buffer.ToString();
if (!IsFinished(user)) {
unfinished_uses++;
}
@@ -785,8 +828,8 @@ bool MemoryUsageTracker::Check() const {
// The while instruction reuses its input buffers as output buffers so
// don't double count its buffers if it is currently executing.
if (IsCurrentlyLive(buffer.id) &&
!(buffer.defining_instruction == in_progress_instruction_ &&
in_progress_instruction_->opcode() == HloOpcode::kWhile)) {
!(buffer.defining_instruction == in_progress_item_ &&
in_progress_item_->instruction->opcode() == HloOpcode::kWhile)) {
live_size += AllocatedSize(buffer.id);
}
}
@@ -830,26 +873,26 @@ int64 RematerializationCost(const HloInstruction* instruction,
// candidate which reduce memory use at the program point of the current
// instruction as indicated by memory_tracker. nullptr is returned if no
// candidate can be found.
HloInstruction* PickRematerializationCandidate(
const MemoryUsageTracker& memory_tracker,
const InstructionList& instruction_list,
const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist,
int64 memory_limit_bytes) {
HloInstruction* best = nullptr;
Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
const InstructionList& instruction_list,
int64 memory_limit_bytes) {
Item* best_item = nullptr;
int64 best_cost = 0;
// TODO(b/35244891): This is currently quadratic in the number of HLO
// instructions.
for (HloInstruction* candidate : instruction_list.instructions()) {
if (!memory_tracker.IsPlaced(candidate)) {
// Only iterate up to the currently placed instruction as indicated by
// memory_tracker. We are trying to reduce memory usage at the placed
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
if (!item->placed) {
// Only iterate up to the currently placed instruction.
// We are trying to reduce memory usage at the placed
// instruction so rematerializing later values is of no benefit.
break;
}
HloInstruction* candidate = item->instruction;
VLOG(5) << "considering rematerialization candidate " << candidate->name();
if (ContainsKey(blacklist, candidate)) {
if (item->blacklisted) {
// Skip instructions on the blacklist to avoid infinite loops of
// rematerializing the same instruction(s) repeatedly.
VLOG(5) << "candidate " << candidate->name()
@@ -864,7 +907,7 @@ HloInstruction* PickRematerializationCandidate(
}
const int64 memory_reduced =
memory_tracker.MemoryReducedIfRematerialized(candidate);
memory_tracker.MemoryReducedIfRematerialized(item);
if (memory_reduced <= 0) {
VLOG(5) << "candidate " << candidate->name()
@@ -878,13 +921,13 @@ HloInstruction* PickRematerializationCandidate(
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
<< memory_reduced << ", cost per byte " << cost;
if (best == nullptr || cost < best_cost) {
if (best_item == nullptr || cost < best_cost) {
VLOG(5) << "candidate " << candidate->name() << " now best";
best = candidate;
best_item = item;
best_cost = cost;
}
}
return best;
return best_item;
}
} // namespace
@@ -896,8 +939,10 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
instruction_list);
int64 peak_memory = tracker.memory_usage();
for (const HloInstruction* instruction : order) {
TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction));
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
TF_ASSIGN_OR_RETURN(int64 callee_usage,
CalledComputationsMemoryUsage(instruction));
peak_memory =
@@ -939,11 +984,6 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
*points_to_analysis_, instruction_list);
bool changed = false;
// To avoid an infinite loop rematerializing the same set of instructions ad
// infinitum, keep a blacklist of instructions which should not be
// rematerialized.
tensorflow::gtl::FlatSet<const HloInstruction*> blacklist;
// If the rematerialization makes the source instruction dead, then the
// rematerialization is added to 'remat_move_instructions' (the
// rematerialization is essentially a move). If the next rematerialization of
@@ -967,17 +1007,17 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// (program point) if memory_usage exceeds the specified limit then
// rematerialize HLO instructions until memory_usage is reduced.
int64 instruction_index = 0;
for (auto list_it = instruction_list.instructions().begin();
list_it != instruction_list.instructions().end(); ++list_it) {
HloInstruction* instruction = *list_it;
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
TF_ASSIGN_OR_RETURN(int64 callee_usage,
CalledComputationsMemoryUsage(instruction));
TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(instruction));
TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));
VLOG(2) << "Program point at " << instruction->name()
<< ", memory usage = " << memory_tracker.memory_usage()
<< ", callee usage = " << callee_usage << ", [" << instruction_index
<< "/" << instruction_list.instructions().size() << "]";
<< "/" << instruction_list.size() << "]";
instruction_index++;
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
@@ -987,10 +1027,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
callee_usage)
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
HloInstruction* best = PickRematerializationCandidate(
memory_tracker, instruction_list, blacklist, memory_limit_bytes);
Item* best_item = PickRematerializationCandidate(
memory_tracker, instruction_list, memory_limit_bytes);
if (best == nullptr) {
if (best_item == nullptr) {
VLOG(3) << "Unable to find rematerialization candidate at program "
"point "
<< instruction->name() << ". Memory usage = "
@@ -999,13 +1039,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
break;
}
HloInstruction* best = best_item->instruction;
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
<< memory_tracker.MemoryReducedIfRematerialized(best) << ")";
<< memory_tracker.MemoryReducedIfRematerialized(best_item) << ")";
changed = true;
remat_count++;
HloInstruction* remat =
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
Item* remat_item = instruction_list.CreateItem(remat);
// Replace each remaining use of 'best' with the rematerialization.
std::vector<HloInstruction*> best_users_copy = best->users();
@@ -1019,22 +1061,28 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// Account for the rematerialization in the memory tracker.
TF_RETURN_IF_ERROR(
memory_tracker.AddRematerializedInstruction(best, remat));
memory_tracker.AddRematerializedInstruction(best_item, remat_item));
// Insert rematerialized instruction right before the earliest unplaced
// use of the instruction *and* the earliest unplaced last use of any
// operands of remat. Unplaced uses of the remat's operands are included
// because we don't want to extend the live range of remat's operands as
// this could increase memory usage.
std::vector<HloInstruction*> place_before = remat->users();
ItemList place_before;
for (auto user : remat->users()) {
place_before.push_back(instruction_list.GetItem(user));
}
for (auto* operand : remat->operands()) {
for (auto* operand_user : operand->users()) {
if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) {
place_before.push_back(operand_user);
if (operand_user != remat) {
Item* operand_user_item = instruction_list.GetItem(operand_user);
if (!operand_user_item->placed) {
place_before.push_back(operand_user_item);
}
}
}
}
instruction_list.InsertBeforeInstructions(remat, place_before);
instruction_list.InsertBeforeInstructions(remat_item, place_before);
// If the rematerialized instruction is dead then rematerialization is
// essentially a move. Don't delete the instruction now because we don't
@@ -1048,7 +1096,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// instruction it was a copying of. Now 'remat' is a rematerialization
// of 'best' and kills 'best'. Stop rematerializing this instruction
// to avoid an infinite loop.
blacklist.insert(remat);
instruction_list.Blacklist(remat);
}
remat_move_instructions.insert(remat);
} else {
@@ -1116,10 +1164,13 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
computation_peak_memory_.at(computation) = peak_memory;
// Update order to include rematerialized instructions.
sequence->at(computation)
.assign(instruction_list.instructions().begin(),
instruction_list.instructions().end());
auto& dst = sequence->at(computation);
dst.clear();
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
dst.push_back(instruction);
}
rematerialized_computations_.insert(computation);
instructions_rematerialized_ += remat_count;