Improve performance of compilation by ~8% by speeding up the
hlo rematerialization pass. Changes: . Wrap each HloInstruction* inside an Item structure that keeps associated data. This allows us to get rid of a bunch of hash tables indexed by HloInstruction*. * Switch to an intrusive linked list (instead of std::list) so that we can avoid a hash table that maps to std::list::iterator. * Use inlined vector in a few places. PiperOrigin-RevId: 163848365
This commit is contained in:
parent
6d77a01293
commit
66f1485424
@ -47,6 +47,12 @@ namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// Potential optimizations:
|
||||
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
|
||||
// of candidates.
|
||||
// . Cache IsRematerializable in Item? Only correct if control
|
||||
// predecessors and successors don't change.
|
||||
|
||||
// Returns true if the given instruction is rematerializable.
|
||||
bool IsRematerializable(const HloInstruction* instruction) {
|
||||
// Conservatively, don't rematerialize instruction with control
|
||||
@ -79,126 +85,202 @@ bool IsRematerializable(const HloInstruction* instruction) {
|
||||
}
|
||||
}
|
||||
|
||||
// Type holding a unique identifier for each Buffer object.
|
||||
using BufferId = int64;
|
||||
using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;
|
||||
|
||||
// We wrap HloInstruction* with an Item that holds auxiliary
|
||||
// per-instruction state.
|
||||
struct Item {
|
||||
HloInstruction* instruction;
|
||||
|
||||
// True once the instruction is marked as placed (when BeginInstruction
|
||||
// has been called for this instruction).
|
||||
bool placed = false;
|
||||
|
||||
// To avoid an infinite loop rematerializing the same set of
|
||||
// instructions ad infinitum, keep a blacklist of instructions
|
||||
// which should not be rematerialized.
|
||||
bool blacklisted = false;
|
||||
|
||||
// The buffers defined by this instruction.
|
||||
BufferIdList buffers_defined;
|
||||
|
||||
// The buffers used by this instruction.
|
||||
BufferIdList buffers_used;
|
||||
|
||||
private:
|
||||
friend class InstructionList;
|
||||
|
||||
// Items are arranged in a doubly linked list.
|
||||
Item* next;
|
||||
Item* prev;
|
||||
|
||||
// List is ordered by position, which can however be duplicated as
|
||||
// new instructions are inserted. See InsertBeforeInstructions
|
||||
// comment for details.
|
||||
int64 position;
|
||||
};
|
||||
|
||||
using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>;
|
||||
|
||||
// Class which maintains an ordered list of instructions with fast insertion
|
||||
// before arbitrary elements.
|
||||
class InstructionList {
|
||||
public:
|
||||
explicit InstructionList(const std::vector<const HloInstruction*>& order) {
|
||||
int64 position = 0;
|
||||
Item* last = nullptr;
|
||||
for (const HloInstruction* inst : order) {
|
||||
instructions_.push_back(const_cast<HloInstruction*>(inst));
|
||||
instruction_iterators_.insert({const_cast<HloInstruction*>(inst),
|
||||
std::next(instructions_.end(), -1)});
|
||||
// Add a new item to the linked list.
|
||||
Item* item = new Item;
|
||||
item->next = nullptr;
|
||||
item->prev = last;
|
||||
if (last == nullptr) {
|
||||
first_ = item;
|
||||
} else {
|
||||
last->next = item;
|
||||
}
|
||||
last = item;
|
||||
|
||||
// Initially position numbers are uniquely assigned in order. Later as
|
||||
// instructions are added with InsertBefore* methods, some instructions
|
||||
// may have duplicate position numbers, but the values will be guaranteed
|
||||
// to be monotonically increasing through the list, and so is still useful
|
||||
// for quickly(-ish) determining the order of arbitrary instructions in
|
||||
// the list.
|
||||
position_number_[inst] = position;
|
||||
first_at_position_[position] = inst;
|
||||
item->instruction = const_cast<HloInstruction*>(inst);
|
||||
item->position = position;
|
||||
position++;
|
||||
|
||||
item_map_[inst] = item;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the list of instructions.
|
||||
const std::list<HloInstruction*>& instructions() const {
|
||||
return instructions_;
|
||||
~InstructionList() {
|
||||
for (Item* item = first_; item != nullptr;) {
|
||||
Item* next = item->next;
|
||||
delete item;
|
||||
item = next;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert instruction 'to_insert' immediately before instruction 'before' in
|
||||
// the list.
|
||||
void InsertBefore(HloInstruction* to_insert, HloInstruction* before) {
|
||||
VLOG(3) << "InsertBefore: " << to_insert->name() << " before "
|
||||
<< before->name();
|
||||
auto it = instruction_iterators_.find(before);
|
||||
CHECK(it != instruction_iterators_.end());
|
||||
instruction_iterators_.insert(
|
||||
{to_insert, instructions_.insert(it->second, to_insert)});
|
||||
// Assign the same position number to the newly added instruction as
|
||||
// 'before'. This guarantees monotonicity of the position numbers, but not
|
||||
// uniqueness.
|
||||
int64 pos = position_number_.at(before);
|
||||
position_number_[to_insert] = pos;
|
||||
if (first_at_position_.at(pos) == before) {
|
||||
first_at_position_[pos] = to_insert;
|
||||
}
|
||||
size_t size() const { return item_map_.size(); }
|
||||
|
||||
// For ordered iteration over items.
|
||||
// for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
|
||||
Item* first() const { return first_; }
|
||||
Item* next(Item* item) const { return item->next; }
|
||||
|
||||
// Creates an Item for the given instruction, but doesn't add it to the list.
|
||||
// (Use InsertBeforeInstructions to add the Item to the list.)
|
||||
Item* CreateItem(HloInstruction* inst) {
|
||||
Item* item = new Item;
|
||||
item->instruction = inst;
|
||||
CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice";
|
||||
return item;
|
||||
}
|
||||
|
||||
// Return the Item corresponding to inst.
|
||||
Item* GetItem(const HloInstruction* inst) const {
|
||||
auto iter = item_map_.find(inst);
|
||||
CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
// Insert instruction 'to_insert' immediately before the earliest instruction
|
||||
// in 'before_instructions'.
|
||||
//
|
||||
// Each instruction gets a non-decreasing ordinal number. We use this to let
|
||||
// InsertBeforeInstructions quickly insert an instruction before the earliest
|
||||
// instruction in a set of instructions. If position_number_[a] <
|
||||
// position_number_[b] then 'a' comes before 'b' in the list. If the position
|
||||
// numbers are the same then nothing can be said about their order without
|
||||
// examining the list.
|
||||
//
|
||||
// On object construction this ordinal is precisely the instruction's index
|
||||
// in the list. Later, instructions inserted via InsertBefore receive
|
||||
// duplicate values. However, monotonicity is preserved.
|
||||
void InsertBeforeInstructions(
|
||||
HloInstruction* to_insert,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> before_instructions) {
|
||||
VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {"
|
||||
<< tensorflow::str_util::Join(
|
||||
before_instructions, ", ",
|
||||
[](string* out, HloInstruction* inst) {
|
||||
tensorflow::strings::StrAppend(out, inst->name());
|
||||
})
|
||||
Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
|
||||
VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
|
||||
<< " before {"
|
||||
<< tensorflow::str_util::Join(before_instructions, ", ",
|
||||
[](string* out, Item* item) {
|
||||
tensorflow::strings::StrAppend(
|
||||
out, item->instruction->name());
|
||||
})
|
||||
<< "}";
|
||||
|
||||
// Find the minimal position number of any instruction in
|
||||
// 'before_instructions'.
|
||||
CHECK(!before_instructions.empty());
|
||||
int64 min_position_number = std::numeric_limits<int64>::max();
|
||||
for (const HloInstruction* instruction : before_instructions) {
|
||||
min_position_number =
|
||||
std::min(min_position_number, position_number_.at(instruction));
|
||||
Item* min_position_item = nullptr;
|
||||
for (Item* item : before_instructions) {
|
||||
if (min_position_item == nullptr ||
|
||||
item->position < min_position_item->position) {
|
||||
min_position_item = item;
|
||||
}
|
||||
}
|
||||
|
||||
// Because more than one instruction in 'before_instructions' may have a
|
||||
// position number of 'min_position_number', find the first such instruction
|
||||
// with position number 'min_position_number'.
|
||||
for (auto it = instruction_iterators_.at(
|
||||
first_at_position_.at(min_position_number));
|
||||
it != instructions_.end() &&
|
||||
position_number_.at(*it) == min_position_number;
|
||||
++it) {
|
||||
if (std::find(before_instructions.begin(), before_instructions.end(),
|
||||
*it) != before_instructions.end()) {
|
||||
return InsertBefore(to_insert, *it);
|
||||
}
|
||||
|
||||
// First find first instruction with the min position.
|
||||
while (min_position_item->prev != nullptr &&
|
||||
min_position_item->position == min_position_item->prev->position) {
|
||||
min_position_item = min_position_item->prev;
|
||||
}
|
||||
LOG(FATAL) << "Expected to find instruction in before_instructions with "
|
||||
"position number "
|
||||
<< min_position_number;
|
||||
|
||||
// Now scan forwards until we find one of the before_instructions.
|
||||
while (std::find(before_instructions.begin(), before_instructions.end(),
|
||||
min_position_item) == before_instructions.end()) {
|
||||
min_position_item = min_position_item->next;
|
||||
}
|
||||
return InsertBefore(to_insert, min_position_item);
|
||||
}
|
||||
|
||||
void Blacklist(const HloInstruction* inst) {
|
||||
GetItem(inst)->blacklisted = true;
|
||||
}
|
||||
|
||||
private:
|
||||
// List of instructions.
|
||||
std::list<HloInstruction*> instructions_;
|
||||
// Insert instruction 'item' immediately before 'before' in the list.
|
||||
void InsertBefore(Item* item, Item* before) {
|
||||
VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
|
||||
<< before->instruction->name();
|
||||
// Insert new item into linked list.
|
||||
item->prev = before->prev;
|
||||
item->next = before;
|
||||
before->prev = item;
|
||||
if (item->prev != nullptr) {
|
||||
item->prev->next = item;
|
||||
} else {
|
||||
first_ = item;
|
||||
}
|
||||
|
||||
// Iterators for each instruction in the list.
|
||||
tensorflow::gtl::FlatMap<const HloInstruction*,
|
||||
std::list<HloInstruction*>::iterator>
|
||||
instruction_iterators_;
|
||||
// Assign the same position number to the newly added instruction as
|
||||
// 'before'. This guarantees monotonicity of the position numbers, but not
|
||||
// uniqueness.
|
||||
item->position = before->position;
|
||||
}
|
||||
|
||||
// A number assigned to each instruction which increases monotonically through
|
||||
// 'instructions_'. Used to facilitate fast insertion of an instruction before
|
||||
// the earliest instruction in a set of instructions
|
||||
// (InsertBeforeInstructions) by enabling fast-ish ordering queries between
|
||||
// instructions. If position_number_[a] < position_number_[b] then 'a' comes
|
||||
// before 'b' in the list. If the position numbers are the same then nothing
|
||||
// can be said about their order without examining the list.
|
||||
//
|
||||
// On object construction this value is precisely the instruction's ordinal
|
||||
// position in the list. Instructions inserted via InsertBefore receive
|
||||
// duplicate values. However, monotonicity is preserved.
|
||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64> position_number_;
|
||||
Item* first_;
|
||||
|
||||
// The first instruction in the list assigned a particular position number.
|
||||
tensorflow::gtl::FlatMap<int64, const HloInstruction*> first_at_position_;
|
||||
// Item for each instruction.
|
||||
tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_;
|
||||
};
|
||||
|
||||
// Return the HloInstructions which use the given LogicalBuffer. Sets
|
||||
// Return the items which use the given LogicalBuffer. Sets
|
||||
// has_indirect_users to whether any of the uses is indirect. A use is indirect
|
||||
// if the instruction defining logical_buffer is not an operand of the use. This
|
||||
// can happen via buffer aliasing (eg, tuples).
|
||||
std::vector<const HloInstruction*> GetUsers(
|
||||
const LogicalBuffer* logical_buffer,
|
||||
const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) {
|
||||
std::vector<const HloInstruction*> users;
|
||||
ItemList GetUsers(const InstructionList& instruction_list,
|
||||
const LogicalBuffer* logical_buffer,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
bool* has_indirect_users) {
|
||||
ItemList users;
|
||||
// To identify uses iterate through all HloInstruction users of the
|
||||
// BufferAliases of the logical buffer.
|
||||
*has_indirect_users = false;
|
||||
@ -219,8 +301,9 @@ std::vector<const HloInstruction*> GetUsers(
|
||||
}
|
||||
// A buffer may be used by the instruction via more than one alias. For
|
||||
// example, a buffer which appears in more than one element of a tuple.
|
||||
if (std::find(users.begin(), users.end(), user) == users.end()) {
|
||||
users.push_back(user);
|
||||
Item* user_item = instruction_list.GetItem(user);
|
||||
if (std::find(users.begin(), users.end(), user_item) == users.end()) {
|
||||
users.push_back(user_item);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -244,7 +327,7 @@ class MemoryUsageTracker {
|
||||
// EndInstruction) to accurately model memory usage. At BeginInstruction the
|
||||
// memory for the output value(s) of the current instruction is allocated. At
|
||||
// EndInstruction memory for dead operand(s) is freed.
|
||||
Status BeginInstruction(const HloInstruction* instruction);
|
||||
Status BeginInstruction(Item* item);
|
||||
|
||||
// Finishes the placement of the current instruction. This frees any dead
|
||||
// operands or dead result of the instruction. This must be called after
|
||||
@ -253,40 +336,31 @@ class MemoryUsageTracker {
|
||||
|
||||
// Returns the number of bytes that the current memory usage will be reduced
|
||||
// if the given instruction is rematerialized.
|
||||
int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const;
|
||||
int64 MemoryReducedIfRematerialized(Item* item) const;
|
||||
|
||||
// Adjusts memory usage to account for the rematerialization of
|
||||
// original_instruction for all remaining unplaced uses. The rematerialization
|
||||
// is remat_instruction. This method should be called after the HLO graph has
|
||||
// original_item for all remaining unplaced uses. The rematerialization
|
||||
// is remat_item. This method should be called after the HLO graph has
|
||||
// been transformed (rematerialization instruction created and connected to
|
||||
// uses).
|
||||
Status AddRematerializedInstruction(HloInstruction* original_instruction,
|
||||
HloInstruction* remat_instruction);
|
||||
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
|
||||
|
||||
// Returns whether the given instruction has been placed (BeginInstruction
|
||||
// has been called with 'instruction' as the argument).
|
||||
bool IsPlaced(const HloInstruction* instruction) const {
|
||||
return ContainsKey(placed_instructions_, instruction);
|
||||
return instruction_list_.GetItem(instruction)->placed;
|
||||
}
|
||||
|
||||
// Returns the current memory usage. This is the sum of sizes of all live
|
||||
// values.
|
||||
int64 memory_usage() const { return memory_usage_; }
|
||||
|
||||
// Returns the current instruction being placed.
|
||||
const HloInstruction* in_progress_instruction() const {
|
||||
return in_progress_instruction_;
|
||||
}
|
||||
|
||||
// Check invariants of the data structure. This is expensive to call.
|
||||
bool Check() const;
|
||||
|
||||
string ToString() const;
|
||||
|
||||
private:
|
||||
// Type holding a unique identifier for each Buffer object.
|
||||
using BufferId = int64;
|
||||
|
||||
// A Buffer represents a single LogicalBuffer in the computation including
|
||||
// various metadata useful for tracking liveness of the value. A LogicalBuffer
|
||||
// is not used directly because the HLO graph is transformed and
|
||||
@ -298,7 +372,7 @@ class MemoryUsageTracker {
|
||||
const BufferId id;
|
||||
|
||||
// The instruction which defines this buffer.
|
||||
const HloInstruction* defining_instruction;
|
||||
Item* defining_instruction;
|
||||
|
||||
// The materialized size of the buffer in bytes.
|
||||
const int64 size;
|
||||
@ -312,16 +386,17 @@ class MemoryUsageTracker {
|
||||
bool has_indirect_uses;
|
||||
|
||||
// The instructions which use this buffer.
|
||||
std::vector<const HloInstruction*> users;
|
||||
ItemList users;
|
||||
|
||||
// The number of users (HloInstructions) of this buffer which have not yet
|
||||
// been placed in the sequence.
|
||||
int64 unfinished_user_count;
|
||||
|
||||
string ToString() const {
|
||||
return tensorflow::strings::StrCat("Buffer ", id, " (defined by ",
|
||||
defining_instruction->name(),
|
||||
", size ", size, " bytes)");
|
||||
return tensorflow::strings::StrCat(
|
||||
"Buffer ", id, " (defined by ",
|
||||
defining_instruction->instruction->name(), ", size ", size,
|
||||
" bytes)");
|
||||
}
|
||||
};
|
||||
|
||||
@ -333,25 +408,24 @@ class MemoryUsageTracker {
|
||||
const HloRematerialization::ShapeSizeFunction& size_function,
|
||||
bool live_out) {
|
||||
bool has_indirect_uses = false;
|
||||
std::vector<const HloInstruction*> users =
|
||||
GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses);
|
||||
return NewBuffer(logical_buffer->instruction(),
|
||||
ItemList users = GetUsers(instruction_list_, logical_buffer,
|
||||
points_to_analysis, &has_indirect_uses);
|
||||
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
|
||||
size_function(logical_buffer->shape()), std::move(users),
|
||||
live_out, has_indirect_uses);
|
||||
}
|
||||
|
||||
// Create a new buffer representing a rematerialization of given buffer for
|
||||
// the given uses.
|
||||
Buffer& RematerializeBuffer(
|
||||
const Buffer& original_buffer, const HloInstruction* remat_instruction,
|
||||
std::vector<const HloInstruction*>&& rematerialized_uses) {
|
||||
CHECK(IsPlaced(original_buffer.defining_instruction));
|
||||
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
|
||||
ItemList&& rematerialized_uses) {
|
||||
CHECK(original_buffer.defining_instruction->placed);
|
||||
CHECK(!original_buffer.has_indirect_uses);
|
||||
CHECK(!original_buffer.live_out);
|
||||
for (const HloInstruction* use : rematerialized_uses) {
|
||||
CHECK(!IsPlaced(use));
|
||||
for (Item* use : rematerialized_uses) {
|
||||
CHECK(!use->placed);
|
||||
}
|
||||
return NewBuffer(remat_instruction, original_buffer.size,
|
||||
return NewBuffer(remat_item, original_buffer.size,
|
||||
std::move(rematerialized_uses), /*live_out=*/false,
|
||||
/*has_indirect_uses=*/false);
|
||||
}
|
||||
@ -362,7 +436,7 @@ class MemoryUsageTracker {
|
||||
// different computation.
|
||||
int64 AllocatedSize(BufferId buffer_id) const {
|
||||
const Buffer& buffer = buffers_.at(buffer_id);
|
||||
HloOpcode def_opcode = buffer.defining_instruction->opcode();
|
||||
HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
|
||||
if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
|
||||
return 0;
|
||||
} else {
|
||||
@ -372,18 +446,17 @@ class MemoryUsageTracker {
|
||||
|
||||
// Returns true if BeginInstruction and EndInstruction has been called for the
|
||||
// given instruction.
|
||||
bool IsFinished(const HloInstruction* instruction) const {
|
||||
return IsPlaced(instruction) && instruction != in_progress_instruction_;
|
||||
bool IsFinished(Item* item) const {
|
||||
return item->placed && item != in_progress_item_;
|
||||
}
|
||||
|
||||
// Returns whether the given buffer is being used by the in-progress
|
||||
// instruction.
|
||||
bool IsInUse(BufferId buffer_id) const {
|
||||
if (in_progress_instruction_ == nullptr) {
|
||||
if (in_progress_item_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const std::vector<BufferId>& in_progress_uses =
|
||||
buffers_used_by_instruction_.at(in_progress_instruction_);
|
||||
const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
|
||||
return std::find(in_progress_uses.begin(), in_progress_uses.end(),
|
||||
buffer_id) != in_progress_uses.end();
|
||||
}
|
||||
@ -392,14 +465,13 @@ class MemoryUsageTracker {
|
||||
// point.
|
||||
bool IsCurrentlyLive(BufferId buffer_id) const {
|
||||
const Buffer& buffer = buffers_[buffer_id];
|
||||
return (IsPlaced(buffer.defining_instruction) &&
|
||||
return (buffer.defining_instruction->placed &&
|
||||
buffer.unfinished_user_count > 0);
|
||||
}
|
||||
|
||||
// Create a new buffer, add it to buffers_, and return a reference.
|
||||
Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size,
|
||||
std::vector<const HloInstruction*>&& users, bool live_out,
|
||||
bool has_indirect_uses) {
|
||||
Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
|
||||
bool live_out, bool has_indirect_uses) {
|
||||
int buffer_id = buffers_.size();
|
||||
buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
|
||||
has_indirect_uses, users,
|
||||
@ -419,19 +491,7 @@ class MemoryUsageTracker {
|
||||
|
||||
// The instruction currently being placed. This value is non-null only
|
||||
// between the calling of BeginInstruction and EndInstruction.
|
||||
const HloInstruction* in_progress_instruction_ = nullptr;
|
||||
|
||||
// The buffers defined by each instruction.
|
||||
std::unordered_map<const HloInstruction*, std::vector<BufferId>>
|
||||
buffers_defined_by_instruction_;
|
||||
|
||||
// The buffers used by each instruction.
|
||||
std::unordered_map<const HloInstruction*, std::vector<BufferId>>
|
||||
buffers_used_by_instruction_;
|
||||
|
||||
// The set of instructions which have been placed. That is, BeginInstruction
|
||||
// has been called with the instruction as an argument.
|
||||
tensorflow::gtl::FlatSet<const HloInstruction*> placed_instructions_;
|
||||
Item* in_progress_item_ = nullptr;
|
||||
|
||||
// All buffers in the computation.
|
||||
std::vector<Buffer> buffers_;
|
||||
@ -443,22 +503,15 @@ MemoryUsageTracker::MemoryUsageTracker(
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const InstructionList& instruction_list)
|
||||
: computation_(computation), instruction_list_(instruction_list) {
|
||||
// Iterate through all LogicalBuffers in the computation and gather the
|
||||
// instructions which define them in buffers_defined_by_instruction_ and the
|
||||
// instructions which use them in buffers_used_by_instruction_.
|
||||
for (auto& instruction : computation_->instructions()) {
|
||||
// Initialize empty vectors for defs and uses of each instruction.
|
||||
buffers_used_by_instruction_[instruction.get()];
|
||||
buffers_defined_by_instruction_[instruction.get()];
|
||||
}
|
||||
|
||||
tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set =
|
||||
points_to_analysis.GetPointsToSet(computation_->root_instruction())
|
||||
.CreateFlattenedSet();
|
||||
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
|
||||
logical_buffer_to_buffer_id;
|
||||
|
||||
for (const HloInstruction* instruction : instruction_list_.instructions()) {
|
||||
for (auto* item = instruction_list_.first(); item != nullptr;
|
||||
item = instruction_list_.next(item)) {
|
||||
const HloInstruction* const instruction = item->instruction;
|
||||
for (const LogicalBuffer* logical_buffer :
|
||||
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
|
||||
Buffer* buffer;
|
||||
@ -481,22 +534,22 @@ MemoryUsageTracker::MemoryUsageTracker(
|
||||
|
||||
// Add users of while to Buffer users.
|
||||
bool unused;
|
||||
for (const HloInstruction* user :
|
||||
GetUsers(logical_buffer, points_to_analysis, &unused)) {
|
||||
if (std::find(buffer->users.begin(), buffer->users.end(), user) ==
|
||||
buffer->users.end()) {
|
||||
buffer->users.push_back(user);
|
||||
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
|
||||
points_to_analysis, &unused)) {
|
||||
if (std::find(buffer->users.begin(), buffer->users.end(),
|
||||
user_item) == buffer->users.end()) {
|
||||
buffer->users.push_back(user_item);
|
||||
buffer->unfinished_user_count++;
|
||||
buffers_used_by_instruction_.at(user).push_back(buffer->id);
|
||||
user_item->buffers_used.push_back(buffer->id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
buffer = &CreateBufferFromLogicalBuffer(
|
||||
logical_buffer, points_to_analysis, size_function,
|
||||
ContainsKey(live_out_set, logical_buffer));
|
||||
buffers_defined_by_instruction_.at(instruction).push_back(buffer->id);
|
||||
for (const HloInstruction* user : buffer->users) {
|
||||
buffers_used_by_instruction_.at(user).push_back(buffer->id);
|
||||
item->buffers_defined.push_back(buffer->id);
|
||||
for (Item* user : buffer->users) {
|
||||
user->buffers_used.push_back(buffer->id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -507,15 +560,16 @@ MemoryUsageTracker::MemoryUsageTracker(
|
||||
DCHECK(Check());
|
||||
}
|
||||
|
||||
Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) {
|
||||
Status MemoryUsageTracker::BeginInstruction(Item* item) {
|
||||
const HloInstruction* instruction = item->instruction;
|
||||
VLOG(3) << "BeginInstruction " << instruction->name();
|
||||
TF_RET_CHECK(in_progress_instruction_ == nullptr);
|
||||
in_progress_instruction_ = instruction;
|
||||
TF_RET_CHECK(in_progress_item_ == nullptr);
|
||||
in_progress_item_ = item;
|
||||
|
||||
placed_instructions_.insert(in_progress_instruction_);
|
||||
item->placed = true;
|
||||
|
||||
// All buffers defined by this instruction need memory.
|
||||
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
|
||||
for (BufferId buffer_id : item->buffers_defined) {
|
||||
VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString()
|
||||
<< " is now live.";
|
||||
memory_usage_ += AllocatedSize(buffer_id);
|
||||
@ -532,11 +586,10 @@ Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) {
|
||||
}
|
||||
|
||||
Status MemoryUsageTracker::EndInstruction() {
|
||||
TF_RET_CHECK(in_progress_instruction_ != nullptr);
|
||||
VLOG(3) << "EndInstruction " << in_progress_instruction_->name();
|
||||
TF_RET_CHECK(in_progress_item_ != nullptr);
|
||||
VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();
|
||||
|
||||
for (BufferId buffer_id :
|
||||
buffers_used_by_instruction_.at(in_progress_instruction_)) {
|
||||
for (BufferId buffer_id : in_progress_item_->buffers_used) {
|
||||
Buffer& buffer = buffers_.at(buffer_id);
|
||||
buffer.unfinished_user_count--;
|
||||
CHECK_GE(buffer.unfinished_user_count, 0)
|
||||
@ -551,8 +604,7 @@ Status MemoryUsageTracker::EndInstruction() {
|
||||
|
||||
// If any buffer defined by this instruction has no uses, then memory can be
|
||||
// reclaimed immediately.
|
||||
for (BufferId buffer_id :
|
||||
buffers_defined_by_instruction_.at(in_progress_instruction_)) {
|
||||
for (BufferId buffer_id : in_progress_item_->buffers_defined) {
|
||||
const Buffer& buffer = buffers_.at(buffer_id);
|
||||
if (buffer.unfinished_user_count == 0) {
|
||||
VLOG(3) << " " << buffer.ToString() << " is immediately dead.";
|
||||
@ -561,7 +613,7 @@ Status MemoryUsageTracker::EndInstruction() {
|
||||
}
|
||||
}
|
||||
|
||||
in_progress_instruction_ = nullptr;
|
||||
in_progress_item_ = nullptr;
|
||||
|
||||
VLOG(3) << " memory usage = " << memory_usage_;
|
||||
VLOG(10) << ToString();
|
||||
@ -571,10 +623,9 @@ Status MemoryUsageTracker::EndInstruction() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
||||
const HloInstruction* instruction) const {
|
||||
CHECK_NE(in_progress_instruction_, nullptr);
|
||||
if (!IsPlaced(instruction) || instruction == in_progress_instruction_) {
|
||||
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
|
||||
CHECK_NE(in_progress_item_, nullptr);
|
||||
if (!item->placed || item == in_progress_item_) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -589,7 +640,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
||||
// be live at this program point, so initially set memory_reduced to the
|
||||
// size of its defined values.
|
||||
int64 memory_reduced = 0;
|
||||
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
|
||||
for (BufferId buffer_id : item->buffers_defined) {
|
||||
// Avoid rematerializing instructions with indirect uses as it is difficult
|
||||
// to reason about liveness after rematerializing the instruction.
|
||||
// TODO(b/37714814): Consider rematerialzing instructions with indirect
|
||||
@ -605,7 +656,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
||||
|
||||
// Account for any logical buffers whose live range must be extended across
|
||||
// this program point.
|
||||
for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) {
|
||||
for (BufferId buffer_id : item->buffers_used) {
|
||||
if (!IsCurrentlyLive(buffer_id)) {
|
||||
// This logical buffer is used by 'instruction' but is not live at this
|
||||
// program point. Rematerializing 'instruction' will extend the buffer's
|
||||
@ -617,28 +668,23 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
||||
return memory_reduced;
|
||||
}
|
||||
|
||||
Status MemoryUsageTracker::AddRematerializedInstruction(
|
||||
HloInstruction* original_instruction, HloInstruction* remat_instruction) {
|
||||
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
|
||||
Item* remat_item) {
|
||||
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
|
||||
<< original_instruction->name()
|
||||
<< ", remat_instruction = " << remat_instruction->name();
|
||||
<< original_item->instruction->name()
|
||||
<< ", remat_instruction = " << remat_item->instruction->name();
|
||||
|
||||
TF_RET_CHECK(in_progress_instruction_ != nullptr);
|
||||
TF_RET_CHECK(IsPlaced(original_instruction));
|
||||
TF_RET_CHECK(!IsPlaced(remat_instruction));
|
||||
CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction));
|
||||
CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction));
|
||||
TF_RET_CHECK(in_progress_item_ != nullptr);
|
||||
TF_RET_CHECK(original_item->placed);
|
||||
TF_RET_CHECK(!remat_item->placed);
|
||||
|
||||
// Construct the list of buffers used and defined by the rematerialization.
|
||||
buffers_defined_by_instruction_[remat_instruction];
|
||||
buffers_used_by_instruction_[remat_instruction] =
|
||||
buffers_used_by_instruction_.at(original_instruction);
|
||||
remat_item->buffers_used = original_item->buffers_used;
|
||||
|
||||
// Account for the additional buffer uses created by the new rematerialization
|
||||
// instruction. Update memory usage if the rematerialization makes a dead
|
||||
// buffer live again.
|
||||
for (BufferId buffer_id :
|
||||
buffers_used_by_instruction_.at(original_instruction)) {
|
||||
for (BufferId buffer_id : original_item->buffers_used) {
|
||||
Buffer& buffer = buffers_.at(buffer_id);
|
||||
if (buffer.unfinished_user_count == 0) {
|
||||
// Buffer used by this instruction was dead, now is alive.
|
||||
@ -646,20 +692,19 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
|
||||
}
|
||||
|
||||
buffer.unfinished_user_count++;
|
||||
buffer.users.push_back(remat_instruction);
|
||||
buffer.users.push_back(remat_item);
|
||||
}
|
||||
|
||||
// Create a new set of Buffers defined by the new rematerialization
|
||||
// instruction. Update the internal data structures and memory use to account
|
||||
// for them.
|
||||
for (BufferId old_buffer_id :
|
||||
buffers_defined_by_instruction_.at(original_instruction)) {
|
||||
for (BufferId old_buffer_id : original_item->buffers_defined) {
|
||||
Buffer& old_buffer = buffers_.at(old_buffer_id);
|
||||
|
||||
std::vector<const HloInstruction*> placed_users;
|
||||
std::vector<const HloInstruction*> unplaced_users;
|
||||
for (const HloInstruction* user : old_buffer.users) {
|
||||
if (IsPlaced(user)) {
|
||||
ItemList placed_users;
|
||||
ItemList unplaced_users;
|
||||
for (Item* user : old_buffer.users) {
|
||||
if (user->placed) {
|
||||
CHECK(IsFinished(user));
|
||||
placed_users.push_back(user);
|
||||
} else {
|
||||
@ -672,14 +717,12 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
|
||||
// Buffer is now dead.
|
||||
memory_usage_ -= AllocatedSize(old_buffer.id);
|
||||
|
||||
Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction,
|
||||
std::move(unplaced_users));
|
||||
Buffer& new_buffer =
|
||||
RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
|
||||
|
||||
buffers_defined_by_instruction_.at(remat_instruction)
|
||||
.push_back(new_buffer.id);
|
||||
for (const HloInstruction* user : new_buffer.users) {
|
||||
std::vector<BufferId>& buffers_used =
|
||||
buffers_used_by_instruction_.at(user);
|
||||
remat_item->buffers_defined.push_back(new_buffer.id);
|
||||
for (Item* user : new_buffer.users) {
|
||||
BufferIdList& buffers_used = user->buffers_used;
|
||||
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
|
||||
new_buffer.id);
|
||||
}
|
||||
@ -699,13 +742,14 @@ string MemoryUsageTracker::ToString() const {
|
||||
tensorflow::strings::StrAppend(
|
||||
&output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
|
||||
memory_usage(), " bytes)");
|
||||
for (const HloInstruction* instruction : instruction_list_.instructions()) {
|
||||
string inprogress =
|
||||
instruction == in_progress_instruction_ ? " in-progress" : "";
|
||||
string placed = IsPlaced(instruction) ? " placed" : "";
|
||||
for (auto* item = instruction_list_.first(); item != nullptr;
|
||||
item = instruction_list_.next(item)) {
|
||||
const HloInstruction* instruction = item->instruction;
|
||||
string inprogress = item == in_progress_item_ ? " in-progress" : "";
|
||||
string placed = item->placed ? " placed" : "";
|
||||
tensorflow::strings::StrAppend(&output, " ", instruction->name(),
|
||||
inprogress, placed, "\n Defines:\n");
|
||||
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
|
||||
for (BufferId buffer_id : item->buffers_defined) {
|
||||
const Buffer& buffer = buffers_[buffer_id];
|
||||
string live = IsCurrentlyLive(buffer_id) ? " live" : "";
|
||||
tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live,
|
||||
@ -713,7 +757,7 @@ string MemoryUsageTracker::ToString() const {
|
||||
" unfinished uses\n");
|
||||
}
|
||||
tensorflow::strings::StrAppend(&output, " Uses:\n");
|
||||
for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) {
|
||||
for (BufferId buffer_id : item->buffers_used) {
|
||||
tensorflow::strings::StrAppend(&output, " ",
|
||||
buffers_[buffer_id].ToString(), "\n");
|
||||
}
|
||||
@ -722,14 +766,14 @@ string MemoryUsageTracker::ToString() const {
|
||||
}
|
||||
|
||||
bool MemoryUsageTracker::Check() const {
|
||||
auto elements_are_unique = [](const std::vector<BufferId>& vec) {
|
||||
auto elements_are_unique = [](const BufferIdList& vec) {
|
||||
return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
|
||||
};
|
||||
|
||||
// Verify buffers_defined_by_instruction_.
|
||||
// Verify buffers_defined per instruction.
|
||||
for (auto& instruction : computation_->instructions()) {
|
||||
const std::vector<BufferId>& defined_buffers =
|
||||
buffers_defined_by_instruction_.at(instruction.get());
|
||||
const BufferIdList& defined_buffers =
|
||||
instruction_list_.GetItem(instruction.get())->buffers_defined;
|
||||
CHECK(elements_are_unique(defined_buffers))
|
||||
<< "Instruction " << instruction->name()
|
||||
<< " does not have unique defined buffers: "
|
||||
@ -740,7 +784,7 @@ bool MemoryUsageTracker::Check() const {
|
||||
});
|
||||
|
||||
for (const Buffer& buffer : buffers_) {
|
||||
if (buffer.defining_instruction == instruction.get()) {
|
||||
if (buffer.defining_instruction->instruction == instruction.get()) {
|
||||
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
|
||||
buffer.id) != defined_buffers.end())
|
||||
<< "Instruction " << instruction->name()
|
||||
@ -749,10 +793,10 @@ bool MemoryUsageTracker::Check() const {
|
||||
}
|
||||
}
|
||||
|
||||
// Verify buffers_used_by_instruction_.
|
||||
// Verify buffers_used per instruction.
|
||||
for (auto& instruction : computation_->instructions()) {
|
||||
const std::vector<BufferId>& used_buffers =
|
||||
buffers_used_by_instruction_.at(instruction.get());
|
||||
const BufferIdList& used_buffers =
|
||||
instruction_list_.GetItem(instruction.get())->buffers_used;
|
||||
CHECK(elements_are_unique(used_buffers))
|
||||
<< "Instruction " << instruction->name()
|
||||
<< " does not have unique used buffers: "
|
||||
@ -764,13 +808,12 @@ bool MemoryUsageTracker::Check() const {
|
||||
}
|
||||
for (const Buffer& buffer : buffers_) {
|
||||
int64 unfinished_uses = 0;
|
||||
for (const HloInstruction* user : buffer.users) {
|
||||
const std::vector<BufferId>& used_buffers =
|
||||
buffers_used_by_instruction_.at(user);
|
||||
for (Item* user : buffer.users) {
|
||||
const BufferIdList& used_buffers = user->buffers_used;
|
||||
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
|
||||
used_buffers.end())
|
||||
<< "Instruction " << user->name() << " used buffers is missing "
|
||||
<< buffer.ToString();
|
||||
<< "Instruction " << user->instruction->name()
|
||||
<< " used buffers is missing " << buffer.ToString();
|
||||
if (!IsFinished(user)) {
|
||||
unfinished_uses++;
|
||||
}
|
||||
@ -785,8 +828,8 @@ bool MemoryUsageTracker::Check() const {
|
||||
// The while instruction reuses its input buffers as output buffers so
|
||||
// don't double count its buffers if it is currently executing.
|
||||
if (IsCurrentlyLive(buffer.id) &&
|
||||
!(buffer.defining_instruction == in_progress_instruction_ &&
|
||||
in_progress_instruction_->opcode() == HloOpcode::kWhile)) {
|
||||
!(buffer.defining_instruction == in_progress_item_ &&
|
||||
in_progress_item_->instruction->opcode() == HloOpcode::kWhile)) {
|
||||
live_size += AllocatedSize(buffer.id);
|
||||
}
|
||||
}
|
||||
@ -830,26 +873,26 @@ int64 RematerializationCost(const HloInstruction* instruction,
|
||||
// candidate which reduce memory use at the program point of the current
|
||||
// instruction as indicated by memory_tracker. nullptr is returned if no
|
||||
// candidate can be found.
|
||||
HloInstruction* PickRematerializationCandidate(
|
||||
const MemoryUsageTracker& memory_tracker,
|
||||
const InstructionList& instruction_list,
|
||||
const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist,
|
||||
int64 memory_limit_bytes) {
|
||||
HloInstruction* best = nullptr;
|
||||
Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
|
||||
const InstructionList& instruction_list,
|
||||
int64 memory_limit_bytes) {
|
||||
Item* best_item = nullptr;
|
||||
int64 best_cost = 0;
|
||||
|
||||
// TODO(b/35244891): This is currently quadratic in the number of HLO
|
||||
// instructions.
|
||||
for (HloInstruction* candidate : instruction_list.instructions()) {
|
||||
if (!memory_tracker.IsPlaced(candidate)) {
|
||||
// Only iterate up to the currently placed instruction as indicated by
|
||||
// memory_tracker. We are trying to reduce memory usage at the placed
|
||||
for (auto* item = instruction_list.first(); item != nullptr;
|
||||
item = instruction_list.next(item)) {
|
||||
if (!item->placed) {
|
||||
// Only iterate up to the currently placed instruction.
|
||||
// We are trying to reduce memory usage at the placed
|
||||
// instruction so rematerializing later values is of no benefit.
|
||||
break;
|
||||
}
|
||||
HloInstruction* candidate = item->instruction;
|
||||
VLOG(5) << "considering rematerialization candidate " << candidate->name();
|
||||
|
||||
if (ContainsKey(blacklist, candidate)) {
|
||||
if (item->blacklisted) {
|
||||
// Skip instructions on the blacklist to avoid infinite loops of
|
||||
// rematerializing the same instruction(s) repeatedly.
|
||||
VLOG(5) << "candidate " << candidate->name()
|
||||
@ -864,7 +907,7 @@ HloInstruction* PickRematerializationCandidate(
|
||||
}
|
||||
|
||||
const int64 memory_reduced =
|
||||
memory_tracker.MemoryReducedIfRematerialized(candidate);
|
||||
memory_tracker.MemoryReducedIfRematerialized(item);
|
||||
|
||||
if (memory_reduced <= 0) {
|
||||
VLOG(5) << "candidate " << candidate->name()
|
||||
@ -878,13 +921,13 @@ HloInstruction* PickRematerializationCandidate(
|
||||
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
|
||||
<< memory_reduced << ", cost per byte " << cost;
|
||||
|
||||
if (best == nullptr || cost < best_cost) {
|
||||
if (best_item == nullptr || cost < best_cost) {
|
||||
VLOG(5) << "candidate " << candidate->name() << " now best";
|
||||
best = candidate;
|
||||
best_item = item;
|
||||
best_cost = cost;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
return best_item;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -896,8 +939,10 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
|
||||
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
|
||||
instruction_list);
|
||||
int64 peak_memory = tracker.memory_usage();
|
||||
for (const HloInstruction* instruction : order) {
|
||||
TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction));
|
||||
for (auto* item = instruction_list.first(); item != nullptr;
|
||||
item = instruction_list.next(item)) {
|
||||
const HloInstruction* instruction = item->instruction;
|
||||
TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
|
||||
TF_ASSIGN_OR_RETURN(int64 callee_usage,
|
||||
CalledComputationsMemoryUsage(instruction));
|
||||
peak_memory =
|
||||
@ -939,11 +984,6 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
*points_to_analysis_, instruction_list);
|
||||
bool changed = false;
|
||||
|
||||
// To avoid an infinite loop rematerializing the same set of instructions ad
|
||||
// infinitum, keep a blacklist of instructions which should not be
|
||||
// rematerialized.
|
||||
tensorflow::gtl::FlatSet<const HloInstruction*> blacklist;
|
||||
|
||||
// If the rematerialization makes the source instruction dead, then the
|
||||
// rematerialization is added to 'remat_move_instructions' (the
|
||||
// rematerialization is essentially a move). If the next rematerialization of
|
||||
@ -967,17 +1007,17 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
// (program point) if memory_usage exceeds the specified limit then
|
||||
// rematerialize HLO instructions until memory_usage is reduced.
|
||||
int64 instruction_index = 0;
|
||||
for (auto list_it = instruction_list.instructions().begin();
|
||||
list_it != instruction_list.instructions().end(); ++list_it) {
|
||||
HloInstruction* instruction = *list_it;
|
||||
for (auto* item = instruction_list.first(); item != nullptr;
|
||||
item = instruction_list.next(item)) {
|
||||
const HloInstruction* instruction = item->instruction;
|
||||
TF_ASSIGN_OR_RETURN(int64 callee_usage,
|
||||
CalledComputationsMemoryUsage(instruction));
|
||||
TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(instruction));
|
||||
TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));
|
||||
|
||||
VLOG(2) << "Program point at " << instruction->name()
|
||||
<< ", memory usage = " << memory_tracker.memory_usage()
|
||||
<< ", callee usage = " << callee_usage << ", [" << instruction_index
|
||||
<< "/" << instruction_list.instructions().size() << "]";
|
||||
<< "/" << instruction_list.size() << "]";
|
||||
instruction_index++;
|
||||
|
||||
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
|
||||
@ -987,10 +1027,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
callee_usage)
|
||||
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
|
||||
|
||||
HloInstruction* best = PickRematerializationCandidate(
|
||||
memory_tracker, instruction_list, blacklist, memory_limit_bytes);
|
||||
Item* best_item = PickRematerializationCandidate(
|
||||
memory_tracker, instruction_list, memory_limit_bytes);
|
||||
|
||||
if (best == nullptr) {
|
||||
if (best_item == nullptr) {
|
||||
VLOG(3) << "Unable to find rematerialization candidate at program "
|
||||
"point "
|
||||
<< instruction->name() << ". Memory usage = "
|
||||
@ -999,13 +1039,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
break;
|
||||
}
|
||||
|
||||
HloInstruction* best = best_item->instruction;
|
||||
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
|
||||
<< memory_tracker.MemoryReducedIfRematerialized(best) << ")";
|
||||
<< memory_tracker.MemoryReducedIfRematerialized(best_item) << ")";
|
||||
changed = true;
|
||||
remat_count++;
|
||||
|
||||
HloInstruction* remat =
|
||||
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
|
||||
Item* remat_item = instruction_list.CreateItem(remat);
|
||||
|
||||
// Replace each remaining use of 'best' with the rematerialization.
|
||||
std::vector<HloInstruction*> best_users_copy = best->users();
|
||||
@ -1019,22 +1061,28 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
|
||||
// Account for the rematerialization in the memory tracker.
|
||||
TF_RETURN_IF_ERROR(
|
||||
memory_tracker.AddRematerializedInstruction(best, remat));
|
||||
memory_tracker.AddRematerializedInstruction(best_item, remat_item));
|
||||
|
||||
// Insert rematerialized instruction right before the earliest unplaced
|
||||
// use of the instruction *and* the earliest unplaced last use of any
|
||||
// operands of remat. Unplaced uses of the remat's operands are included
|
||||
// because we don't want to extend the live range of remat's operands as
|
||||
// this could increase memory usage.
|
||||
std::vector<HloInstruction*> place_before = remat->users();
|
||||
ItemList place_before;
|
||||
for (auto user : remat->users()) {
|
||||
place_before.push_back(instruction_list.GetItem(user));
|
||||
}
|
||||
for (auto* operand : remat->operands()) {
|
||||
for (auto* operand_user : operand->users()) {
|
||||
if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) {
|
||||
place_before.push_back(operand_user);
|
||||
if (operand_user != remat) {
|
||||
Item* operand_user_item = instruction_list.GetItem(operand_user);
|
||||
if (!operand_user_item->placed) {
|
||||
place_before.push_back(operand_user_item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
instruction_list.InsertBeforeInstructions(remat, place_before);
|
||||
instruction_list.InsertBeforeInstructions(remat_item, place_before);
|
||||
|
||||
// If the rematerialized instruction is dead then rematerialization is
|
||||
// essentially a move. Don't delete the instruction now because we don't
|
||||
@ -1048,7 +1096,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
// instruction it was a copying of. Now 'remat' is a rematerialization
|
||||
// of 'best' and kills 'best'. Stop rematerializing this instruction
|
||||
// to avoid an infinite loop.
|
||||
blacklist.insert(remat);
|
||||
instruction_list.Blacklist(remat);
|
||||
}
|
||||
remat_move_instructions.insert(remat);
|
||||
} else {
|
||||
@ -1116,10 +1164,13 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
computation_peak_memory_.at(computation) = peak_memory;
|
||||
|
||||
// Update order to include rematerialized instructions.
|
||||
sequence->at(computation)
|
||||
.assign(instruction_list.instructions().begin(),
|
||||
instruction_list.instructions().end());
|
||||
|
||||
auto& dst = sequence->at(computation);
|
||||
dst.clear();
|
||||
for (auto* item = instruction_list.first(); item != nullptr;
|
||||
item = instruction_list.next(item)) {
|
||||
const HloInstruction* instruction = item->instruction;
|
||||
dst.push_back(instruction);
|
||||
}
|
||||
rematerialized_computations_.insert(computation);
|
||||
|
||||
instructions_rematerialized_ += remat_count;
|
||||
|
Loading…
Reference in New Issue
Block a user