Improve performance of compilation by ~8% by speeding up the HLO rematerialization pass.

Changes:
* Wrap each HloInstruction* inside an Item structure that keeps the
  associated per-instruction data. This allows us to get rid of several
  hash tables indexed by HloInstruction*.
* Switch to an intrusive linked list (instead of std::list) so
  that we can avoid a hash table that maps to std::list::iterator
  (see the sketch just before the diff below).
* Use an inlined vector in a few places.
PiperOrigin-RevId: 163848365
author A. Unique TensorFlower, 2017-08-01 10:29:30 -07:00 (committed by TensorFlower Gardener)
parent 6d77a01293
commit 66f1485424
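
To make the first two bullets concrete, here is a minimal standalone sketch of the new layout (not the commit's actual code): each list node owns its next/prev pointers and its per-instruction state, so a single map from instruction to Item* replaces the old map from instruction to std::list iterator plus the side tables for placement, blacklisting, and buffer lists. The Instr and IntrusiveList names, and the use of std::unordered_map and std::vector in place of the TensorFlow gtl containers, are illustrative simplifications.

    #include <cassert>
    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    struct Instr {};  // stand-in for HloInstruction

    // Intrusive node: the list pointers and per-instruction state live in the
    // node itself, so one map from Instr* to Item* replaces the previous map
    // from instruction to list iterator plus the separate per-instruction tables.
    struct Item {
      Instr* instruction = nullptr;
      bool placed = false;
      bool blacklisted = false;
      std::vector<int64_t> buffers_defined;  // InlinedVector<BufferId, 3> in the pass
      std::vector<int64_t> buffers_used;
      Item* prev = nullptr;
      Item* next = nullptr;
      int64_t position = 0;
    };

    class IntrusiveList {
     public:
      explicit IntrusiveList(const std::vector<Instr*>& order) {
        int64_t position = 0;
        for (Instr* inst : order) {
          Item* item = new Item;
          item->instruction = inst;
          item->position = position++;
          item->prev = last_;
          if (last_ == nullptr) {
            first_ = item;
          } else {
            last_->next = item;
          }
          last_ = item;
          item_map_[inst] = item;  // the only hash table that remains
        }
      }
      ~IntrusiveList() {
        for (Item* item = first_; item != nullptr;) {
          Item* next = item->next;
          delete item;
          item = next;
        }
      }
      Item* GetItem(Instr* inst) const { return item_map_.at(inst); }
      Item* first() const { return first_; }

     private:
      Item* first_ = nullptr;
      Item* last_ = nullptr;
      std::unordered_map<Instr*, Item*> item_map_;
    };

    int main() {
      Instr a, b, c;
      IntrusiveList list({&a, &b, &c});
      list.GetItem(&b)->placed = true;  // O(1) lookup, no iterator map required
      assert(list.first()->next->placed);
      return 0;
    }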


@ -47,6 +47,12 @@ namespace xla {
namespace { namespace {
// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
// of candidates.
// . Cache IsRematerializable in Item? Only correct if control
// predecessors and successors don't change.
// Returns true if the given instruction is rematerializable. // Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) { bool IsRematerializable(const HloInstruction* instruction) {
// Conservatively, don't rematerialize instruction with control // Conservatively, don't rematerialize instruction with control
@ -79,126 +85,202 @@ bool IsRematerializable(const HloInstruction* instruction) {
} }
} }
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;
// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
HloInstruction* instruction;
// True once the instruction is marked as placed (when BeginInstruction
// has been called for this instruction).
bool placed = false;
// To avoid an infinite loop rematerializing the same set of
// instructions ad infinitum, keep a blacklist of instructions
// which should not be rematerialized.
bool blacklisted = false;
// The buffers defined by this instruction.
BufferIdList buffers_defined;
// The buffers used by this instruction.
BufferIdList buffers_used;
private:
friend class InstructionList;
// Items are arranged in a doubly linked list.
Item* next;
Item* prev;
// List is ordered by position, which can however be duplicated as
// new instructions are inserted. See InsertBeforeInstructions
// comment for details.
int64 position;
};
using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>;
// Class which maintains an ordered list of instructions with fast insertion // Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements. // before arbitrary elements.
class InstructionList { class InstructionList {
public: public:
explicit InstructionList(const std::vector<const HloInstruction*>& order) { explicit InstructionList(const std::vector<const HloInstruction*>& order) {
int64 position = 0; int64 position = 0;
Item* last = nullptr;
for (const HloInstruction* inst : order) { for (const HloInstruction* inst : order) {
instructions_.push_back(const_cast<HloInstruction*>(inst)); // Add a new item to the linked list.
instruction_iterators_.insert({const_cast<HloInstruction*>(inst), Item* item = new Item;
std::next(instructions_.end(), -1)}); item->next = nullptr;
item->prev = last;
if (last == nullptr) {
first_ = item;
} else {
last->next = item;
}
last = item;
// Initially position numbers are uniquely assigned in order. Later as // Initially position numbers are uniquely assigned in order. Later as
// instructions are added with InsertBefore* methods, some instructions // instructions are added with InsertBefore* methods, some instructions
// may have duplicate position numbers, but the values will be guaranteed // may have duplicate position numbers, but the values will be guaranteed
// to be monotonically increasing through the list, and so is still useful // to be monotonically increasing through the list, and so is still useful
// for quickly(-ish) determining the order of arbitrary instructions in // for quickly(-ish) determining the order of arbitrary instructions in
// the list. // the list.
position_number_[inst] = position; item->instruction = const_cast<HloInstruction*>(inst);
first_at_position_[position] = inst; item->position = position;
position++; position++;
item_map_[inst] = item;
} }
} }
// Returns the list of instructions. ~InstructionList() {
const std::list<HloInstruction*>& instructions() const { for (Item* item = first_; item != nullptr;) {
return instructions_; Item* next = item->next;
delete item;
item = next;
}
} }
// Insert instruction 'to_insert' immediately before instruction 'before' in size_t size() const { return item_map_.size(); }
// the list.
void InsertBefore(HloInstruction* to_insert, HloInstruction* before) { // For ordered iteration over items.
VLOG(3) << "InsertBefore: " << to_insert->name() << " before " // for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
<< before->name(); Item* first() const { return first_; }
auto it = instruction_iterators_.find(before); Item* next(Item* item) const { return item->next; }
CHECK(it != instruction_iterators_.end());
instruction_iterators_.insert( // Creates an Item for the given instruction, but doesn't add it to the list.
{to_insert, instructions_.insert(it->second, to_insert)}); // (Use InsertBeforeInstructions to add the Item to the list.)
// Assign the same position number to the newly added instruction as Item* CreateItem(HloInstruction* inst) {
// 'before'. This guarantees monotonicity of the position numbers, but not Item* item = new Item;
// uniqueness. item->instruction = inst;
int64 pos = position_number_.at(before); CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice";
position_number_[to_insert] = pos; return item;
if (first_at_position_.at(pos) == before) { }
first_at_position_[pos] = to_insert;
} // Return the Item corresponding to inst.
Item* GetItem(const HloInstruction* inst) const {
auto iter = item_map_.find(inst);
CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
return iter->second;
} }
// Insert instruction 'to_insert' immediately before the earliest instruction // Insert instruction 'to_insert' immediately before the earliest instruction
// in 'before_instructions'. // in 'before_instructions'.
//
// Each instruction gets a non-decreasing ordinal number. We use this to let
// InsertBeforeInstructions quickly insert an instruction before the earliest
// instruction in a set of instructions. If position_number_[a] <
// position_number_[b] then 'a' comes before 'b' in the list. If the position
// numbers are the same then nothing can be said about their order without
// examining the list.
//
// On object construction this ordinal is precisely the instruction's index
// in the list. Later, instructions inserted via InsertBefore receive
// duplicate values. However, monotonicity is preserved.
void InsertBeforeInstructions( void InsertBeforeInstructions(
HloInstruction* to_insert, Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
tensorflow::gtl::ArraySlice<HloInstruction*> before_instructions) { VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {" << " before {"
<< tensorflow::str_util::Join( << tensorflow::str_util::Join(before_instructions, ", ",
before_instructions, ", ", [](string* out, Item* item) {
[](string* out, HloInstruction* inst) { tensorflow::strings::StrAppend(
tensorflow::strings::StrAppend(out, inst->name()); out, item->instruction->name());
}) })
<< "}"; << "}";
// Find the minimal position number of any instruction in // Find the minimal position number of any instruction in
// 'before_instructions'. // 'before_instructions'.
CHECK(!before_instructions.empty()); CHECK(!before_instructions.empty());
int64 min_position_number = std::numeric_limits<int64>::max(); Item* min_position_item = nullptr;
for (const HloInstruction* instruction : before_instructions) { for (Item* item : before_instructions) {
min_position_number = if (min_position_item == nullptr ||
std::min(min_position_number, position_number_.at(instruction)); item->position < min_position_item->position) {
min_position_item = item;
}
} }
// Because more than one instruction in 'before_instructions' may have a // Because more than one instruction in 'before_instructions' may have a
// position number of 'min_position_number', find the first such instruction // position number of 'min_position_number', find the first such instruction
// with position number 'min_position_number'. // with position number 'min_position_number'.
for (auto it = instruction_iterators_.at(
first_at_position_.at(min_position_number)); // First find first instruction with the min position.
it != instructions_.end() && while (min_position_item->prev != nullptr &&
position_number_.at(*it) == min_position_number; min_position_item->position == min_position_item->prev->position) {
++it) { min_position_item = min_position_item->prev;
if (std::find(before_instructions.begin(), before_instructions.end(),
*it) != before_instructions.end()) {
return InsertBefore(to_insert, *it);
}
} }
LOG(FATAL) << "Expected to find instruction in before_instructions with "
"position number " // Now scan forwards until we find one of the before_instructions.
<< min_position_number; while (std::find(before_instructions.begin(), before_instructions.end(),
min_position_item) == before_instructions.end()) {
min_position_item = min_position_item->next;
}
return InsertBefore(to_insert, min_position_item);
}
void Blacklist(const HloInstruction* inst) {
GetItem(inst)->blacklisted = true;
} }
private: private:
// List of instructions. // Insert instruction 'item' immediately before 'before' in the list.
std::list<HloInstruction*> instructions_; void InsertBefore(Item* item, Item* before) {
VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
<< before->instruction->name();
// Insert new item into linked list.
item->prev = before->prev;
item->next = before;
before->prev = item;
if (item->prev != nullptr) {
item->prev->next = item;
} else {
first_ = item;
}
// Iterators for each instruction in the list. // Assign the same position number to the newly added instruction as
tensorflow::gtl::FlatMap<const HloInstruction*, // 'before'. This guarantees monotonicity of the position numbers, but not
std::list<HloInstruction*>::iterator> // uniqueness.
instruction_iterators_; item->position = before->position;
}
// A number assigned to each instruction which increases monotonically through Item* first_;
// 'instructions_'. Used to facilitate fast insertion of an instruction before
// the earliest instruction in a set of instructions
// (InsertBeforeInstructions) by enabling fast-ish ordering queries between
// instructions. If position_number_[a] < position_number_[b] then 'a' comes
// before 'b' in the list. If the position numbers are the same then nothing
// can be said about their order without examining the list.
//
// On object construction this value is precisely the instruction's ordinal
// position in the list. Instructions inserted via InsertBefore receive
// duplicate values. However, monotonicity is preserved.
tensorflow::gtl::FlatMap<const HloInstruction*, int64> position_number_;
// The first instruction in the list assigned a particular position number. // Item for each instruction.
tensorflow::gtl::FlatMap<int64, const HloInstruction*> first_at_position_; tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_;
}; };
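
As an aside, the ordinal-ordering trick documented on InsertBeforeInstructions above can be illustrated with a small self-contained example. Positions start out unique; an inserted item copies the position of its insertion point, so the sequence stays monotonically non-decreasing, and finding the earliest member of a set needs only a position comparison plus a short scan among ties. This sketch uses a plain vector and scans from the front rather than walking prev pointers as the real code does; Node and FindInsertionIndex are illustrative names, not from the commit.

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <limits>
    #include <string>
    #include <vector>

    // Simplified stand-in for Item: a name plus its (possibly duplicated)
    // position. Invariant: positions never decrease along the list.
    struct Node {
      std::string name;
      int64_t position;
    };

    // Mirrors the search in InsertBeforeInstructions: take the smallest position
    // among the targets, back up to the start of that run of equal positions,
    // then scan forward to the first node that is actually one of the targets.
    size_t FindInsertionIndex(const std::vector<Node>& list,
                              const std::vector<std::string>& targets) {
      int64_t min_pos = std::numeric_limits<int64_t>::max();
      for (const Node& n : list) {
        if (std::find(targets.begin(), targets.end(), n.name) != targets.end()) {
          min_pos = std::min(min_pos, n.position);
        }
      }
      size_t i = 0;
      while (list[i].position < min_pos) ++i;  // start of the run sharing min_pos
      while (std::find(targets.begin(), targets.end(), list[i].name) ==
             targets.end()) {
        ++i;  // skip non-targets that share the minimal position
      }
      return i;
    }

    int main() {
      // "x" was inserted before "c", so positions are 0,1,2,2,3: duplicated but
      // still monotonically non-decreasing.
      std::vector<Node> list = {{"a", 0}, {"b", 1}, {"x", 2}, {"c", 2}, {"d", 3}};
      // The earliest of {"c", "d"} is "c"; the run sharing position 2 starts at
      // "x", and the forward scan skips "x" because it is not a target.
      assert(FindInsertionIndex(list, {"c", "d"}) == 3);
      return 0;
    }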
// Return the HloInstructions which use the given LogicalBuffer. Sets // Return the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is indirect // has_indirect_users to whether any of the uses is indirect. A use is indirect
// if the instruction defining logical_buffer is not an operand of the use. This // if the instruction defining logical_buffer is not an operand of the use. This
// can happen via buffer aliasing (eg, tuples). // can happen via buffer aliasing (eg, tuples).
std::vector<const HloInstruction*> GetUsers( ItemList GetUsers(const InstructionList& instruction_list,
const LogicalBuffer* logical_buffer, const LogicalBuffer* logical_buffer,
const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) { const TuplePointsToAnalysis& points_to_analysis,
std::vector<const HloInstruction*> users; bool* has_indirect_users) {
ItemList users;
// To identify uses iterate through all HloInstruction users of the // To identify uses iterate through all HloInstruction users of the
// BufferAliases of the logical buffer. // BufferAliases of the logical buffer.
*has_indirect_users = false; *has_indirect_users = false;
@ -219,8 +301,9 @@ std::vector<const HloInstruction*> GetUsers(
} }
// A buffer may be used by the instruction via more than one alias. For // A buffer may be used by the instruction via more than one alias. For
// example, a buffer which appears in more than one element of a tuple. // example, a buffer which appears in more than one element of a tuple.
if (std::find(users.begin(), users.end(), user) == users.end()) { Item* user_item = instruction_list.GetItem(user);
users.push_back(user); if (std::find(users.begin(), users.end(), user_item) == users.end()) {
users.push_back(user_item);
} }
} }
} }
@ -244,7 +327,7 @@ class MemoryUsageTracker {
// EndInstruction) to accurately model memory usage. At BeginInstruction the // EndInstruction) to accurately model memory usage. At BeginInstruction the
// memory for the output value(s) of the current instruction is allocated. At // memory for the output value(s) of the current instruction is allocated. At
// EndInstruction memory for dead operand(s) is freed. // EndInstruction memory for dead operand(s) is freed.
Status BeginInstruction(const HloInstruction* instruction); Status BeginInstruction(Item* item);
// Finishes the placement of the current instruction. This frees any dead // Finishes the placement of the current instruction. This frees any dead
// operands or dead result of the instruction. This must be called after // operands or dead result of the instruction. This must be called after
@ -253,40 +336,31 @@ class MemoryUsageTracker {
// Returns the number of bytes that the current memory usage will be reduced // Returns the number of bytes that the current memory usage will be reduced
// if the given instruction is rematerialized. // if the given instruction is rematerialized.
int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const; int64 MemoryReducedIfRematerialized(Item* item) const;
// Adjusts memory usage to account for the rematerialization of // Adjusts memory usage to account for the rematerialization of
// original_instruction for all remaining unplaced uses. The rematerialization // original_item for all remaining unplaced uses. The rematerialization
// is remat_instruction. This method should be called after the HLO graph has // is remat_item. This method should be called after the HLO graph has
// been transformed (rematerialization instruction created and connected to // been transformed (rematerialization instruction created and connected to
// uses). // uses).
Status AddRematerializedInstruction(HloInstruction* original_instruction, Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
HloInstruction* remat_instruction);
// Returns whether the given instruction has been placed (BeginInstruction // Returns whether the given instruction has been placed (BeginInstruction
// has been called with 'instruction' as the argument). // has been called with 'instruction' as the argument).
bool IsPlaced(const HloInstruction* instruction) const { bool IsPlaced(const HloInstruction* instruction) const {
return ContainsKey(placed_instructions_, instruction); return instruction_list_.GetItem(instruction)->placed;
} }
// Returns the current memory usage. This is the sum of sizes of all live // Returns the current memory usage. This is the sum of sizes of all live
// values. // values.
int64 memory_usage() const { return memory_usage_; } int64 memory_usage() const { return memory_usage_; }
// Returns the current instruction being placed.
const HloInstruction* in_progress_instruction() const {
return in_progress_instruction_;
}
// Check invariants of the data structure. This is expensive to call. // Check invariants of the data structure. This is expensive to call.
bool Check() const; bool Check() const;
string ToString() const; string ToString() const;
private: private:
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
// A Buffer represents a single LogicalBuffer in the computation including // A Buffer represents a single LogicalBuffer in the computation including
// various metadata useful for tracking liveness of the value. A LogicalBuffer // various metadata useful for tracking liveness of the value. A LogicalBuffer
// is not used directly because the HLO graph is transformed and // is not used directly because the HLO graph is transformed and
@ -298,7 +372,7 @@ class MemoryUsageTracker {
const BufferId id; const BufferId id;
// The instruction which defines this buffer. // The instruction which defines this buffer.
const HloInstruction* defining_instruction; Item* defining_instruction;
// The materialized size of the buffer in bytes. // The materialized size of the buffer in bytes.
const int64 size; const int64 size;
@ -312,16 +386,17 @@ class MemoryUsageTracker {
bool has_indirect_uses; bool has_indirect_uses;
// The instructions which use this buffer. // The instructions which use this buffer.
std::vector<const HloInstruction*> users; ItemList users;
// The number of users (HloInstructions) of this buffer which have not yet // The number of users (HloInstructions) of this buffer which have not yet
// been placed in the sequence. // been placed in the sequence.
int64 unfinished_user_count; int64 unfinished_user_count;
string ToString() const { string ToString() const {
return tensorflow::strings::StrCat("Buffer ", id, " (defined by ", return tensorflow::strings::StrCat(
defining_instruction->name(), "Buffer ", id, " (defined by ",
", size ", size, " bytes)"); defining_instruction->instruction->name(), ", size ", size,
" bytes)");
} }
}; };
@ -333,25 +408,24 @@ class MemoryUsageTracker {
const HloRematerialization::ShapeSizeFunction& size_function, const HloRematerialization::ShapeSizeFunction& size_function,
bool live_out) { bool live_out) {
bool has_indirect_uses = false; bool has_indirect_uses = false;
std::vector<const HloInstruction*> users = ItemList users = GetUsers(instruction_list_, logical_buffer,
GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses); points_to_analysis, &has_indirect_uses);
return NewBuffer(logical_buffer->instruction(), return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
size_function(logical_buffer->shape()), std::move(users), size_function(logical_buffer->shape()), std::move(users),
live_out, has_indirect_uses); live_out, has_indirect_uses);
} }
// Create a new buffer representing a rematerialization of given buffer for // Create a new buffer representing a rematerialization of given buffer for
// the given uses. // the given uses.
Buffer& RematerializeBuffer( Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
const Buffer& original_buffer, const HloInstruction* remat_instruction, ItemList&& rematerialized_uses) {
std::vector<const HloInstruction*>&& rematerialized_uses) { CHECK(original_buffer.defining_instruction->placed);
CHECK(IsPlaced(original_buffer.defining_instruction));
CHECK(!original_buffer.has_indirect_uses); CHECK(!original_buffer.has_indirect_uses);
CHECK(!original_buffer.live_out); CHECK(!original_buffer.live_out);
for (const HloInstruction* use : rematerialized_uses) { for (Item* use : rematerialized_uses) {
CHECK(!IsPlaced(use)); CHECK(!use->placed);
} }
return NewBuffer(remat_instruction, original_buffer.size, return NewBuffer(remat_item, original_buffer.size,
std::move(rematerialized_uses), /*live_out=*/false, std::move(rematerialized_uses), /*live_out=*/false,
/*has_indirect_uses=*/false); /*has_indirect_uses=*/false);
} }
@ -362,7 +436,7 @@ class MemoryUsageTracker {
// different computation. // different computation.
int64 AllocatedSize(BufferId buffer_id) const { int64 AllocatedSize(BufferId buffer_id) const {
const Buffer& buffer = buffers_.at(buffer_id); const Buffer& buffer = buffers_.at(buffer_id);
HloOpcode def_opcode = buffer.defining_instruction->opcode(); HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
if (buffer.live_out || def_opcode == HloOpcode::kParameter) { if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
return 0; return 0;
} else { } else {
@ -372,18 +446,17 @@ class MemoryUsageTracker {
// Returns true if BeginInstruction and EndInstruction has been called for the // Returns true if BeginInstruction and EndInstruction has been called for the
// given instruction. // given instruction.
bool IsFinished(const HloInstruction* instruction) const { bool IsFinished(Item* item) const {
return IsPlaced(instruction) && instruction != in_progress_instruction_; return item->placed && item != in_progress_item_;
} }
// Returns whether the given buffer is being used by the in-progress // Returns whether the given buffer is being used by the in-progress
// instruction. // instruction.
bool IsInUse(BufferId buffer_id) const { bool IsInUse(BufferId buffer_id) const {
if (in_progress_instruction_ == nullptr) { if (in_progress_item_ == nullptr) {
return false; return false;
} }
const std::vector<BufferId>& in_progress_uses = const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
buffers_used_by_instruction_.at(in_progress_instruction_);
return std::find(in_progress_uses.begin(), in_progress_uses.end(), return std::find(in_progress_uses.begin(), in_progress_uses.end(),
buffer_id) != in_progress_uses.end(); buffer_id) != in_progress_uses.end();
} }
@ -392,14 +465,13 @@ class MemoryUsageTracker {
// point. // point.
bool IsCurrentlyLive(BufferId buffer_id) const { bool IsCurrentlyLive(BufferId buffer_id) const {
const Buffer& buffer = buffers_[buffer_id]; const Buffer& buffer = buffers_[buffer_id];
return (IsPlaced(buffer.defining_instruction) && return (buffer.defining_instruction->placed &&
buffer.unfinished_user_count > 0); buffer.unfinished_user_count > 0);
} }
// Create a new buffer, add it to buffers_, and return a reference. // Create a new buffer, add it to buffers_, and return a reference.
Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size, Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
std::vector<const HloInstruction*>&& users, bool live_out, bool live_out, bool has_indirect_uses) {
bool has_indirect_uses) {
int buffer_id = buffers_.size(); int buffer_id = buffers_.size();
buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out, buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
has_indirect_uses, users, has_indirect_uses, users,
@ -419,19 +491,7 @@ class MemoryUsageTracker {
// The instruction currently being placed. This value is non-null only // The instruction currently being placed. This value is non-null only
// between the calling of BeginInstruction and EndInstruction. // between the calling of BeginInstruction and EndInstruction.
const HloInstruction* in_progress_instruction_ = nullptr; Item* in_progress_item_ = nullptr;
// The buffers defined by each instruction.
std::unordered_map<const HloInstruction*, std::vector<BufferId>>
buffers_defined_by_instruction_;
// The buffers used by each instruction.
std::unordered_map<const HloInstruction*, std::vector<BufferId>>
buffers_used_by_instruction_;
// The set of instructions which have been placed. That is, BeginInstruction
// has been called with the instruction as an argument.
tensorflow::gtl::FlatSet<const HloInstruction*> placed_instructions_;
// All buffers in the computation. // All buffers in the computation.
std::vector<Buffer> buffers_; std::vector<Buffer> buffers_;
@ -443,22 +503,15 @@ MemoryUsageTracker::MemoryUsageTracker(
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const InstructionList& instruction_list) const InstructionList& instruction_list)
: computation_(computation), instruction_list_(instruction_list) { : computation_(computation), instruction_list_(instruction_list) {
// Iterate through all LogicalBuffers in the computation and gather the
// instructions which define them in buffers_defined_by_instruction_ and the
// instructions which use them in buffers_used_by_instruction_.
for (auto& instruction : computation_->instructions()) {
// Initialize empty vectors for defs and uses of each instruction.
buffers_used_by_instruction_[instruction.get()];
buffers_defined_by_instruction_[instruction.get()];
}
tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set = tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction()) points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet(); .CreateFlattenedSet();
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId> tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
logical_buffer_to_buffer_id; logical_buffer_to_buffer_id;
for (const HloInstruction* instruction : instruction_list_.instructions()) { for (auto* item = instruction_list_.first(); item != nullptr;
item = instruction_list_.next(item)) {
const HloInstruction* const instruction = item->instruction;
for (const LogicalBuffer* logical_buffer : for (const LogicalBuffer* logical_buffer :
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
Buffer* buffer; Buffer* buffer;
@ -481,22 +534,22 @@ MemoryUsageTracker::MemoryUsageTracker(
// Add users of while to Buffer users. // Add users of while to Buffer users.
bool unused; bool unused;
for (const HloInstruction* user : for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
GetUsers(logical_buffer, points_to_analysis, &unused)) { points_to_analysis, &unused)) {
if (std::find(buffer->users.begin(), buffer->users.end(), user) == if (std::find(buffer->users.begin(), buffer->users.end(),
buffer->users.end()) { user_item) == buffer->users.end()) {
buffer->users.push_back(user); buffer->users.push_back(user_item);
buffer->unfinished_user_count++; buffer->unfinished_user_count++;
buffers_used_by_instruction_.at(user).push_back(buffer->id); user_item->buffers_used.push_back(buffer->id);
} }
} }
} else { } else {
buffer = &CreateBufferFromLogicalBuffer( buffer = &CreateBufferFromLogicalBuffer(
logical_buffer, points_to_analysis, size_function, logical_buffer, points_to_analysis, size_function,
ContainsKey(live_out_set, logical_buffer)); ContainsKey(live_out_set, logical_buffer));
buffers_defined_by_instruction_.at(instruction).push_back(buffer->id); item->buffers_defined.push_back(buffer->id);
for (const HloInstruction* user : buffer->users) { for (Item* user : buffer->users) {
buffers_used_by_instruction_.at(user).push_back(buffer->id); user->buffers_used.push_back(buffer->id);
} }
} }
@ -507,15 +560,16 @@ MemoryUsageTracker::MemoryUsageTracker(
DCHECK(Check()); DCHECK(Check());
} }
Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) { Status MemoryUsageTracker::BeginInstruction(Item* item) {
const HloInstruction* instruction = item->instruction;
VLOG(3) << "BeginInstruction " << instruction->name(); VLOG(3) << "BeginInstruction " << instruction->name();
TF_RET_CHECK(in_progress_instruction_ == nullptr); TF_RET_CHECK(in_progress_item_ == nullptr);
in_progress_instruction_ = instruction; in_progress_item_ = item;
placed_instructions_.insert(in_progress_instruction_); item->placed = true;
// All buffers defined by this instruction need memory. // All buffers defined by this instruction need memory.
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { for (BufferId buffer_id : item->buffers_defined) {
VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString() VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString()
<< " is now live."; << " is now live.";
memory_usage_ += AllocatedSize(buffer_id); memory_usage_ += AllocatedSize(buffer_id);
@ -532,11 +586,10 @@ Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) {
} }
Status MemoryUsageTracker::EndInstruction() { Status MemoryUsageTracker::EndInstruction() {
TF_RET_CHECK(in_progress_instruction_ != nullptr); TF_RET_CHECK(in_progress_item_ != nullptr);
VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();
for (BufferId buffer_id : for (BufferId buffer_id : in_progress_item_->buffers_used) {
buffers_used_by_instruction_.at(in_progress_instruction_)) {
Buffer& buffer = buffers_.at(buffer_id); Buffer& buffer = buffers_.at(buffer_id);
buffer.unfinished_user_count--; buffer.unfinished_user_count--;
CHECK_GE(buffer.unfinished_user_count, 0) CHECK_GE(buffer.unfinished_user_count, 0)
@ -551,8 +604,7 @@ Status MemoryUsageTracker::EndInstruction() {
// If any buffer defined by this instruction has no uses, then memory can be // If any buffer defined by this instruction has no uses, then memory can be
// reclaimed immediately. // reclaimed immediately.
for (BufferId buffer_id : for (BufferId buffer_id : in_progress_item_->buffers_defined) {
buffers_defined_by_instruction_.at(in_progress_instruction_)) {
const Buffer& buffer = buffers_.at(buffer_id); const Buffer& buffer = buffers_.at(buffer_id);
if (buffer.unfinished_user_count == 0) { if (buffer.unfinished_user_count == 0) {
VLOG(3) << " " << buffer.ToString() << " is immediately dead."; VLOG(3) << " " << buffer.ToString() << " is immediately dead.";
@ -561,7 +613,7 @@ Status MemoryUsageTracker::EndInstruction() {
} }
} }
in_progress_instruction_ = nullptr; in_progress_item_ = nullptr;
VLOG(3) << " memory usage = " << memory_usage_; VLOG(3) << " memory usage = " << memory_usage_;
VLOG(10) << ToString(); VLOG(10) << ToString();
@ -571,10 +623,9 @@ Status MemoryUsageTracker::EndInstruction() {
return Status::OK(); return Status::OK();
} }
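
The BeginInstruction/EndInstruction protocol above amounts to simple reference-count accounting: charge a buffer when its defining instruction begins, and release it once its last user finishes. The toy model below shows that accounting on two made-up buffers; it deliberately omits details of the real tracker such as AllocatedSize returning zero for live-out and parameter buffers and the immediate freeing of buffers with no users.

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Toy model: a buffer is charged when its defining instruction begins and
    // released once its last user has finished.
    struct Buffer {
      int64_t size;
      int64_t unfinished_user_count;
    };

    struct Step {
      std::vector<int> defines;  // ids of buffers defined by this instruction
      std::vector<int> uses;     // ids of buffers used by this instruction
    };

    int main() {
      // b0 (8 bytes) is defined at step 0 and used at steps 1 and 2;
      // b1 (4 bytes) is defined at step 1 and used at step 2.
      std::vector<Buffer> buffers = {{8, 2}, {4, 1}};
      std::vector<Step> steps = {{{0}, {}}, {{1}, {0}}, {{}, {0, 1}}};

      int64_t memory_usage = 0;
      int64_t peak = 0;
      for (const Step& s : steps) {
        for (int b : s.defines) memory_usage += buffers[b].size;  // BeginInstruction
        peak = std::max(peak, memory_usage);
        for (int b : s.uses) {                                    // EndInstruction
          if (--buffers[b].unfinished_user_count == 0) {
            memory_usage -= buffers[b].size;  // last user finished: buffer is dead
          }
        }
      }
      assert(peak == 12);         // both buffers live across step 1
      assert(memory_usage == 0);  // everything is dead at the end
      return 0;
    }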
int64 MemoryUsageTracker::MemoryReducedIfRematerialized( int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
const HloInstruction* instruction) const { CHECK_NE(in_progress_item_, nullptr);
CHECK_NE(in_progress_instruction_, nullptr); if (!item->placed || item == in_progress_item_) {
if (!IsPlaced(instruction) || instruction == in_progress_instruction_) {
return 0; return 0;
} }
@ -589,7 +640,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
// be live at this program point, so initially set memory_reduced to the // be live at this program point, so initially set memory_reduced to the
// size of its defined values. // size of its defined values.
int64 memory_reduced = 0; int64 memory_reduced = 0;
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { for (BufferId buffer_id : item->buffers_defined) {
// Avoid rematerializing instructions with indirect uses as it is difficult // Avoid rematerializing instructions with indirect uses as it is difficult
// to reason about liveness after rematerializing the instruction. // to reason about liveness after rematerializing the instruction.
// TODO(b/37714814): Consider rematerialzing instructions with indirect // TODO(b/37714814): Consider rematerialzing instructions with indirect
@ -605,7 +656,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
// Account for any logical buffers whose live range must be extended across // Account for any logical buffers whose live range must be extended across
// this program point. // this program point.
for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { for (BufferId buffer_id : item->buffers_used) {
if (!IsCurrentlyLive(buffer_id)) { if (!IsCurrentlyLive(buffer_id)) {
// This logical buffer is used by 'instruction' but is not live at this // This logical buffer is used by 'instruction' but is not live at this
// program point. Rematerializing 'instruction' will extend the buffer's // program point. Rematerializing 'instruction' will extend the buffer's
@ -617,28 +668,23 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
return memory_reduced; return memory_reduced;
} }
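
Numerically, MemoryReducedIfRematerialized is a difference of two sums: the live buffers the candidate defines at this program point (which would be freed) minus the currently dead operand buffers whose live ranges would have to be extended. A worked example with made-up sizes:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t freed = 32;    // live 32-byte buffer defined by the candidate
      const int64_t live_use = 0;  // a 16-byte operand buffer is already live: no cost
      const int64_t extended = 8;  // an 8-byte operand buffer must be kept alive
      const int64_t memory_reduced = freed - live_use - extended;
      assert(memory_reduced == 24);  // positive, so the candidate can help
      return 0;
    }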
Status MemoryUsageTracker::AddRematerializedInstruction( Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
HloInstruction* original_instruction, HloInstruction* remat_instruction) { Item* remat_item) {
VLOG(3) << "AddRematerializedInstruction: original_instruction = " VLOG(3) << "AddRematerializedInstruction: original_instruction = "
<< original_instruction->name() << original_item->instruction->name()
<< ", remat_instruction = " << remat_instruction->name(); << ", remat_instruction = " << remat_item->instruction->name();
TF_RET_CHECK(in_progress_instruction_ != nullptr); TF_RET_CHECK(in_progress_item_ != nullptr);
TF_RET_CHECK(IsPlaced(original_instruction)); TF_RET_CHECK(original_item->placed);
TF_RET_CHECK(!IsPlaced(remat_instruction)); TF_RET_CHECK(!remat_item->placed);
CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction));
CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction));
// Construct the list of buffers used and defined by the rematerialization. // Construct the list of buffers used and defined by the rematerialization.
buffers_defined_by_instruction_[remat_instruction]; remat_item->buffers_used = original_item->buffers_used;
buffers_used_by_instruction_[remat_instruction] =
buffers_used_by_instruction_.at(original_instruction);
// Account for the additional buffer uses created by the new rematerialization // Account for the additional buffer uses created by the new rematerialization
// instruction. Update memory usage if the rematerialization makes a dead // instruction. Update memory usage if the rematerialization makes a dead
// buffer live again. // buffer live again.
for (BufferId buffer_id : for (BufferId buffer_id : original_item->buffers_used) {
buffers_used_by_instruction_.at(original_instruction)) {
Buffer& buffer = buffers_.at(buffer_id); Buffer& buffer = buffers_.at(buffer_id);
if (buffer.unfinished_user_count == 0) { if (buffer.unfinished_user_count == 0) {
// Buffer used by this instruction was dead, now is alive. // Buffer used by this instruction was dead, now is alive.
@ -646,20 +692,19 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
} }
buffer.unfinished_user_count++; buffer.unfinished_user_count++;
buffer.users.push_back(remat_instruction); buffer.users.push_back(remat_item);
} }
// Create a new set of Buffers defined by the new rematerialization // Create a new set of Buffers defined by the new rematerialization
// instruction. Update the internal data structures and memory use to account // instruction. Update the internal data structures and memory use to account
// for them. // for them.
for (BufferId old_buffer_id : for (BufferId old_buffer_id : original_item->buffers_defined) {
buffers_defined_by_instruction_.at(original_instruction)) {
Buffer& old_buffer = buffers_.at(old_buffer_id); Buffer& old_buffer = buffers_.at(old_buffer_id);
std::vector<const HloInstruction*> placed_users; ItemList placed_users;
std::vector<const HloInstruction*> unplaced_users; ItemList unplaced_users;
for (const HloInstruction* user : old_buffer.users) { for (Item* user : old_buffer.users) {
if (IsPlaced(user)) { if (user->placed) {
CHECK(IsFinished(user)); CHECK(IsFinished(user));
placed_users.push_back(user); placed_users.push_back(user);
} else { } else {
@ -672,14 +717,12 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
// Buffer is now dead. // Buffer is now dead.
memory_usage_ -= AllocatedSize(old_buffer.id); memory_usage_ -= AllocatedSize(old_buffer.id);
Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction, Buffer& new_buffer =
std::move(unplaced_users)); RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
buffers_defined_by_instruction_.at(remat_instruction) remat_item->buffers_defined.push_back(new_buffer.id);
.push_back(new_buffer.id); for (Item* user : new_buffer.users) {
for (const HloInstruction* user : new_buffer.users) { BufferIdList& buffers_used = user->buffers_used;
std::vector<BufferId>& buffers_used =
buffers_used_by_instruction_.at(user);
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
new_buffer.id); new_buffer.id);
} }
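
The bookkeeping above hinges on splitting the original buffer's users into placed and unplaced sets: placed users stay with the (now dead) original buffer, unplaced users move to the new buffer defined by the rematerialization. A minimal sketch of that split, with hypothetical user names not taken from the commit:

    #include <cassert>
    #include <string>
    #include <vector>

    struct User {
      std::string name;
      bool placed;
    };

    int main() {
      std::vector<User> users = {{"u0", true}, {"u1", false}, {"u2", false}};
      std::vector<User> placed_users;
      std::vector<User> unplaced_users;
      for (const User& u : users) {
        (u.placed ? placed_users : unplaced_users).push_back(u);
      }
      // The original buffer, now dead at this point, keeps only u0; the
      // rematerialized buffer serves u1 and u2 from here on.
      assert(placed_users.size() == 1);
      assert(unplaced_users.size() == 2);
      return 0;
    }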
@ -699,13 +742,14 @@ string MemoryUsageTracker::ToString() const {
tensorflow::strings::StrAppend( tensorflow::strings::StrAppend(
&output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
memory_usage(), " bytes)"); memory_usage(), " bytes)");
for (const HloInstruction* instruction : instruction_list_.instructions()) { for (auto* item = instruction_list_.first(); item != nullptr;
string inprogress = item = instruction_list_.next(item)) {
instruction == in_progress_instruction_ ? " in-progress" : ""; const HloInstruction* instruction = item->instruction;
string placed = IsPlaced(instruction) ? " placed" : ""; string inprogress = item == in_progress_item_ ? " in-progress" : "";
string placed = item->placed ? " placed" : "";
tensorflow::strings::StrAppend(&output, " ", instruction->name(), tensorflow::strings::StrAppend(&output, " ", instruction->name(),
inprogress, placed, "\n Defines:\n"); inprogress, placed, "\n Defines:\n");
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { for (BufferId buffer_id : item->buffers_defined) {
const Buffer& buffer = buffers_[buffer_id]; const Buffer& buffer = buffers_[buffer_id];
string live = IsCurrentlyLive(buffer_id) ? " live" : ""; string live = IsCurrentlyLive(buffer_id) ? " live" : "";
tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live,
@ -713,7 +757,7 @@ string MemoryUsageTracker::ToString() const {
" unfinished uses\n"); " unfinished uses\n");
} }
tensorflow::strings::StrAppend(&output, " Uses:\n"); tensorflow::strings::StrAppend(&output, " Uses:\n");
for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { for (BufferId buffer_id : item->buffers_used) {
tensorflow::strings::StrAppend(&output, " ", tensorflow::strings::StrAppend(&output, " ",
buffers_[buffer_id].ToString(), "\n"); buffers_[buffer_id].ToString(), "\n");
} }
@ -722,14 +766,14 @@ string MemoryUsageTracker::ToString() const {
} }
bool MemoryUsageTracker::Check() const { bool MemoryUsageTracker::Check() const {
auto elements_are_unique = [](const std::vector<BufferId>& vec) { auto elements_are_unique = [](const BufferIdList& vec) {
return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size(); return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
}; };
// Verify buffers_defined_by_instruction_. // Verify buffers_defined per instruction.
for (auto& instruction : computation_->instructions()) { for (auto& instruction : computation_->instructions()) {
const std::vector<BufferId>& defined_buffers = const BufferIdList& defined_buffers =
buffers_defined_by_instruction_.at(instruction.get()); instruction_list_.GetItem(instruction.get())->buffers_defined;
CHECK(elements_are_unique(defined_buffers)) CHECK(elements_are_unique(defined_buffers))
<< "Instruction " << instruction->name() << "Instruction " << instruction->name()
<< " does not have unique defined buffers: " << " does not have unique defined buffers: "
@ -740,7 +784,7 @@ bool MemoryUsageTracker::Check() const {
}); });
for (const Buffer& buffer : buffers_) { for (const Buffer& buffer : buffers_) {
if (buffer.defining_instruction == instruction.get()) { if (buffer.defining_instruction->instruction == instruction.get()) {
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
buffer.id) != defined_buffers.end()) buffer.id) != defined_buffers.end())
<< "Instruction " << instruction->name() << "Instruction " << instruction->name()
@ -749,10 +793,10 @@ bool MemoryUsageTracker::Check() const {
} }
} }
// Verify buffers_used_by_instruction_. // Verify buffers_used per instruction.
for (auto& instruction : computation_->instructions()) { for (auto& instruction : computation_->instructions()) {
const std::vector<BufferId>& used_buffers = const BufferIdList& used_buffers =
buffers_used_by_instruction_.at(instruction.get()); instruction_list_.GetItem(instruction.get())->buffers_used;
CHECK(elements_are_unique(used_buffers)) CHECK(elements_are_unique(used_buffers))
<< "Instruction " << instruction->name() << "Instruction " << instruction->name()
<< " does not have unique used buffers: " << " does not have unique used buffers: "
@ -764,13 +808,12 @@ bool MemoryUsageTracker::Check() const {
} }
for (const Buffer& buffer : buffers_) { for (const Buffer& buffer : buffers_) {
int64 unfinished_uses = 0; int64 unfinished_uses = 0;
for (const HloInstruction* user : buffer.users) { for (Item* user : buffer.users) {
const std::vector<BufferId>& used_buffers = const BufferIdList& used_buffers = user->buffers_used;
buffers_used_by_instruction_.at(user);
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
used_buffers.end()) used_buffers.end())
<< "Instruction " << user->name() << " used buffers is missing " << "Instruction " << user->instruction->name()
<< buffer.ToString(); << " used buffers is missing " << buffer.ToString();
if (!IsFinished(user)) { if (!IsFinished(user)) {
unfinished_uses++; unfinished_uses++;
} }
@ -785,8 +828,8 @@ bool MemoryUsageTracker::Check() const {
// The while instruction reuses its input buffers as output buffers so // The while instruction reuses its input buffers as output buffers so
// don't double count its buffers if it is currently executing. // don't double count its buffers if it is currently executing.
if (IsCurrentlyLive(buffer.id) && if (IsCurrentlyLive(buffer.id) &&
!(buffer.defining_instruction == in_progress_instruction_ && !(buffer.defining_instruction == in_progress_item_ &&
in_progress_instruction_->opcode() == HloOpcode::kWhile)) { in_progress_item_->instruction->opcode() == HloOpcode::kWhile)) {
live_size += AllocatedSize(buffer.id); live_size += AllocatedSize(buffer.id);
} }
} }
@ -830,26 +873,26 @@ int64 RematerializationCost(const HloInstruction* instruction,
// candidate which reduce memory use at the program point of the current // candidate which reduce memory use at the program point of the current
// instruction as indicated by memory_tracker. nullptr is returned if no // instruction as indicated by memory_tracker. nullptr is returned if no
// candidate can be found. // candidate can be found.
HloInstruction* PickRematerializationCandidate( Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list,
const InstructionList& instruction_list, int64 memory_limit_bytes) {
const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist, Item* best_item = nullptr;
int64 memory_limit_bytes) {
HloInstruction* best = nullptr;
int64 best_cost = 0; int64 best_cost = 0;
// TODO(b/35244891): This is currently quadratic in the number of HLO // TODO(b/35244891): This is currently quadratic in the number of HLO
// instructions. // instructions.
for (HloInstruction* candidate : instruction_list.instructions()) { for (auto* item = instruction_list.first(); item != nullptr;
if (!memory_tracker.IsPlaced(candidate)) { item = instruction_list.next(item)) {
// Only iterate up to the currently placed instruction as indicated by if (!item->placed) {
// memory_tracker. We are trying to reduce memory usage at the placed // Only iterate up to the currently placed instruction.
// We are trying to reduce memory usage at the placed
// instruction so rematerializing later values is of no benefit. // instruction so rematerializing later values is of no benefit.
break; break;
} }
HloInstruction* candidate = item->instruction;
VLOG(5) << "considering rematerialization candidate " << candidate->name(); VLOG(5) << "considering rematerialization candidate " << candidate->name();
if (ContainsKey(blacklist, candidate)) { if (item->blacklisted) {
// Skip instructions on the blacklist to avoid infinite loops of // Skip instructions on the blacklist to avoid infinite loops of
// rematerializing the same instruction(s) repeatedly. // rematerializing the same instruction(s) repeatedly.
VLOG(5) << "candidate " << candidate->name() VLOG(5) << "candidate " << candidate->name()
@ -864,7 +907,7 @@ HloInstruction* PickRematerializationCandidate(
} }
const int64 memory_reduced = const int64 memory_reduced =
memory_tracker.MemoryReducedIfRematerialized(candidate); memory_tracker.MemoryReducedIfRematerialized(item);
if (memory_reduced <= 0) { if (memory_reduced <= 0) {
VLOG(5) << "candidate " << candidate->name() VLOG(5) << "candidate " << candidate->name()
@ -878,13 +921,13 @@ HloInstruction* PickRematerializationCandidate(
VLOG(5) << "candidate " << candidate->name() << ", memory reduced " VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
<< memory_reduced << ", cost per byte " << cost; << memory_reduced << ", cost per byte " << cost;
if (best == nullptr || cost < best_cost) { if (best_item == nullptr || cost < best_cost) {
VLOG(5) << "candidate " << candidate->name() << " now best"; VLOG(5) << "candidate " << candidate->name() << " now best";
best = candidate; best_item = item;
best_cost = cost; best_cost = cost;
} }
} }
return best; return best_item;
} }
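
The candidate selection above boils down to: among already-placed, non-blacklisted, rematerializable instructions that actually reduce memory, pick the one with the lowest estimated recompute cost per byte saved. A simplified, self-contained version of that loop follows; Candidate and Pick are illustrative names, and the cost is taken as a precomputed field rather than calling RematerializationCost.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct Candidate {
      const char* name;
      bool placed;
      bool blacklisted;
      bool rematerializable;
      int64_t memory_reduced;  // bytes freed at the current program point
      int64_t cost_per_byte;   // estimated recompute cost per byte saved
    };

    // Scan only already-placed candidates, skip blacklisted / non-rematerializable
    // ones and those that do not reduce memory, and keep the cheapest per byte.
    const Candidate* Pick(const std::vector<Candidate>& order) {
      const Candidate* best = nullptr;
      int64_t best_cost = 0;
      for (const Candidate& c : order) {
        if (!c.placed) break;  // instructions after the program point cannot help
        if (c.blacklisted || !c.rematerializable || c.memory_reduced <= 0) continue;
        if (best == nullptr || c.cost_per_byte < best_cost) {
          best = &c;
          best_cost = c.cost_per_byte;
        }
      }
      return best;
    }

    int main() {
      std::vector<Candidate> order = {
          {"a", true, false, true, 64, 3},    // viable, cost 3 per byte
          {"b", true, true, true, 128, 1},    // blacklisted: skipped
          {"c", true, false, true, 32, 2},    // viable, cheaper per byte
          {"d", false, false, true, 256, 1},  // unplaced: loop stops here
      };
      assert(Pick(order)->cost_per_byte == 2);  // "c" wins
      return 0;
    }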
} // namespace } // namespace
@ -896,8 +939,10 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
instruction_list); instruction_list);
int64 peak_memory = tracker.memory_usage(); int64 peak_memory = tracker.memory_usage();
for (const HloInstruction* instruction : order) { for (auto* item = instruction_list.first(); item != nullptr;
TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction)); item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
TF_ASSIGN_OR_RETURN(int64 callee_usage, TF_ASSIGN_OR_RETURN(int64 callee_usage,
CalledComputationsMemoryUsage(instruction)); CalledComputationsMemoryUsage(instruction));
peak_memory = peak_memory =
@ -939,11 +984,6 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
*points_to_analysis_, instruction_list); *points_to_analysis_, instruction_list);
bool changed = false; bool changed = false;
// To avoid an infinite loop rematerializing the same set of instructions ad
// infinitum, keep a blacklist of instructions which should not be
// rematerialized.
tensorflow::gtl::FlatSet<const HloInstruction*> blacklist;
// If the rematerialization makes the source instruction dead, then the // If the rematerialization makes the source instruction dead, then the
// rematerialization is added to 'remat_move_instructions' (the // rematerialization is added to 'remat_move_instructions' (the
// rematerialization is essentially a move). If the next rematerialization of // rematerialization is essentially a move). If the next rematerialization of
@ -967,17 +1007,17 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// (program point) if memory_usage exceeds the specified limit then // (program point) if memory_usage exceeds the specified limit then
// rematerialize HLO instructions until memory_usage is reduced. // rematerialize HLO instructions until memory_usage is reduced.
int64 instruction_index = 0; int64 instruction_index = 0;
for (auto list_it = instruction_list.instructions().begin(); for (auto* item = instruction_list.first(); item != nullptr;
list_it != instruction_list.instructions().end(); ++list_it) { item = instruction_list.next(item)) {
HloInstruction* instruction = *list_it; const HloInstruction* instruction = item->instruction;
TF_ASSIGN_OR_RETURN(int64 callee_usage, TF_ASSIGN_OR_RETURN(int64 callee_usage,
CalledComputationsMemoryUsage(instruction)); CalledComputationsMemoryUsage(instruction));
TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(instruction)); TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));
VLOG(2) << "Program point at " << instruction->name() VLOG(2) << "Program point at " << instruction->name()
<< ", memory usage = " << memory_tracker.memory_usage() << ", memory usage = " << memory_tracker.memory_usage()
<< ", callee usage = " << callee_usage << ", [" << instruction_index << ", callee usage = " << callee_usage << ", [" << instruction_index
<< "/" << instruction_list.instructions().size() << "]"; << "/" << instruction_list.size() << "]";
instruction_index++; instruction_index++;
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
@ -987,10 +1027,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
callee_usage) callee_usage)
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes); << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
HloInstruction* best = PickRematerializationCandidate( Item* best_item = PickRematerializationCandidate(
memory_tracker, instruction_list, blacklist, memory_limit_bytes); memory_tracker, instruction_list, memory_limit_bytes);
if (best == nullptr) { if (best_item == nullptr) {
VLOG(3) << "Unable to find rematerialization candidate at program " VLOG(3) << "Unable to find rematerialization candidate at program "
"point " "point "
<< instruction->name() << ". Memory usage = " << instruction->name() << ". Memory usage = "
@ -999,13 +1039,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
break; break;
} }
HloInstruction* best = best_item->instruction;
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
<< memory_tracker.MemoryReducedIfRematerialized(best) << ")"; << memory_tracker.MemoryReducedIfRematerialized(best_item) << ")";
changed = true; changed = true;
remat_count++; remat_count++;
HloInstruction* remat = HloInstruction* remat =
computation->AddInstruction(best->Clone(/*suffix=*/"remat")); computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
Item* remat_item = instruction_list.CreateItem(remat);
// Replace each remaining use of 'best' with the rematerialization. // Replace each remaining use of 'best' with the rematerialization.
std::vector<HloInstruction*> best_users_copy = best->users(); std::vector<HloInstruction*> best_users_copy = best->users();
@ -1019,22 +1061,28 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// Account for the rematerialization in the memory tracker. // Account for the rematerialization in the memory tracker.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
memory_tracker.AddRematerializedInstruction(best, remat)); memory_tracker.AddRematerializedInstruction(best_item, remat_item));
// Insert rematerialized instruction right before the earliest unplaced // Insert rematerialized instruction right before the earliest unplaced
// use of the instruction *and* the earliest unplaced last use of any // use of the instruction *and* the earliest unplaced last use of any
// operands of remat. Unplaced uses of the remat's operands are included // operands of remat. Unplaced uses of the remat's operands are included
// because we don't want to extend the live range of remat's operands as // because we don't want to extend the live range of remat's operands as
// this could increase memory usage. // this could increase memory usage.
std::vector<HloInstruction*> place_before = remat->users(); ItemList place_before;
for (auto user : remat->users()) {
place_before.push_back(instruction_list.GetItem(user));
}
for (auto* operand : remat->operands()) { for (auto* operand : remat->operands()) {
for (auto* operand_user : operand->users()) { for (auto* operand_user : operand->users()) {
if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) { if (operand_user != remat) {
place_before.push_back(operand_user); Item* operand_user_item = instruction_list.GetItem(operand_user);
if (!operand_user_item->placed) {
place_before.push_back(operand_user_item);
}
} }
} }
} }
instruction_list.InsertBeforeInstructions(remat, place_before); instruction_list.InsertBeforeInstructions(remat_item, place_before);
// If the rematerialized instruction is dead then rematerialization is // If the rematerialized instruction is dead then rematerialization is
// essentially a move. Don't delete the instruction now because we don't // essentially a move. Don't delete the instruction now because we don't
@ -1048,7 +1096,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// instruction it was a copying of. Now 'remat' is a rematerialization // instruction it was a copying of. Now 'remat' is a rematerialization
// of 'best' and kills 'best'. Stop rematerializing this instruction // of 'best' and kills 'best'. Stop rematerializing this instruction
// to avoid an infinite loop. // to avoid an infinite loop.
blacklist.insert(remat); instruction_list.Blacklist(remat);
} }
remat_move_instructions.insert(remat); remat_move_instructions.insert(remat);
} else { } else {
@ -1116,10 +1164,13 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
computation_peak_memory_.at(computation) = peak_memory; computation_peak_memory_.at(computation) = peak_memory;
// Update order to include rematerialized instructions. // Update order to include rematerialized instructions.
sequence->at(computation) auto& dst = sequence->at(computation);
.assign(instruction_list.instructions().begin(), dst.clear();
instruction_list.instructions().end()); for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
dst.push_back(instruction);
}
rematerialized_computations_.insert(computation); rematerialized_computations_.insert(computation);
instructions_rematerialized_ += remat_count; instructions_rematerialized_ += remat_count;