Improve performance of compilation by ~8% by speeding up the
HLO rematerialization pass.

Changes:
* Wrap each HloInstruction* inside an Item structure that keeps associated data. This allows us to get rid of a bunch of hash tables indexed by HloInstruction*.
* Switch to an intrusive linked list (instead of std::list) so that we can avoid a hash table that maps to std::list::iterator.
* Use inlined vector in a few places.

PiperOrigin-RevId: 163848365
This commit is contained in:
parent
6d77a01293
commit
66f1485424
@ -47,6 +47,12 @@ namespace xla {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Potential optimizations:
|
||||||
|
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
|
||||||
|
// of candidates.
|
||||||
|
// . Cache IsRematerializable in Item? Only correct if control
|
||||||
|
// predecessors and successors don't change.
|
||||||
|
|
||||||
// Returns true if the given instruction is rematerializable.
|
// Returns true if the given instruction is rematerializable.
|
||||||
bool IsRematerializable(const HloInstruction* instruction) {
|
bool IsRematerializable(const HloInstruction* instruction) {
|
||||||
// Conservatively, don't rematerialize instruction with control
|
// Conservatively, don't rematerialize instruction with control
|
||||||
@ -79,126 +85,202 @@ bool IsRematerializable(const HloInstruction* instruction) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Type holding a unique identifier for each Buffer object.
|
||||||
|
using BufferId = int64;
|
||||||
|
using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;
|
||||||
|
|
||||||
|
// We wrap HloInstruction* with an Item that holds auxiliary
|
||||||
|
// per-instruction state.
|
||||||
|
struct Item {
|
||||||
|
HloInstruction* instruction;
|
||||||
|
|
||||||
|
// True once the instruction is marked as placed (when BeginInstruction
|
||||||
|
// has been called for this instruction).
|
||||||
|
bool placed = false;
|
||||||
|
|
||||||
|
// To avoid an infinite loop rematerializing the same set of
|
||||||
|
// instructions ad infinitum, keep a blacklist of instructions
|
||||||
|
// which should not be rematerialized.
|
||||||
|
bool blacklisted = false;
|
||||||
|
|
||||||
|
// The buffers defined by this instruction.
|
||||||
|
BufferIdList buffers_defined;
|
||||||
|
|
||||||
|
// The buffers used by this instruction.
|
||||||
|
BufferIdList buffers_used;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class InstructionList;
|
||||||
|
|
||||||
|
// Items are arranged in a doubly linked list.
|
||||||
|
Item* next;
|
||||||
|
Item* prev;
|
||||||
|
|
||||||
|
// List is ordered by position, which can however be duplicated as
|
||||||
|
// new instructions are inserted. See InsertBeforeInstructions
|
||||||
|
// comment for details.
|
||||||
|
int64 position;
|
||||||
|
};
|
||||||
|
|
||||||
|
using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>;
|
||||||
|
|
||||||
// Class which maintains an ordered list of instructions with fast insertion
|
// Class which maintains an ordered list of instructions with fast insertion
|
||||||
// before arbitrary elements.
|
// before arbitrary elements.
|
||||||
class InstructionList {
|
class InstructionList {
|
||||||
public:
|
public:
|
||||||
explicit InstructionList(const std::vector<const HloInstruction*>& order) {
|
explicit InstructionList(const std::vector<const HloInstruction*>& order) {
|
||||||
int64 position = 0;
|
int64 position = 0;
|
||||||
|
Item* last = nullptr;
|
||||||
for (const HloInstruction* inst : order) {
|
for (const HloInstruction* inst : order) {
|
||||||
instructions_.push_back(const_cast<HloInstruction*>(inst));
|
// Add a new item to the linked list.
|
||||||
instruction_iterators_.insert({const_cast<HloInstruction*>(inst),
|
Item* item = new Item;
|
||||||
std::next(instructions_.end(), -1)});
|
item->next = nullptr;
|
||||||
|
item->prev = last;
|
||||||
|
if (last == nullptr) {
|
||||||
|
first_ = item;
|
||||||
|
} else {
|
||||||
|
last->next = item;
|
||||||
|
}
|
||||||
|
last = item;
|
||||||
|
|
||||||
// Initially position numbers are uniquely assigned in order. Later as
|
// Initially position numbers are uniquely assigned in order. Later as
|
||||||
// instructions are added with InsertBefore* methods, some instructions
|
// instructions are added with InsertBefore* methods, some instructions
|
||||||
// may have duplicate position numbers, but the values will be guaranteed
|
// may have duplicate position numbers, but the values will be guaranteed
|
||||||
// to be monotonically increasing through the list, and so is still useful
|
// to be monotonically increasing through the list, and so is still useful
|
||||||
// for quickly(-ish) determining the order of arbitrary instructions in
|
// for quickly(-ish) determining the order of arbitrary instructions in
|
||||||
// the list.
|
// the list.
|
||||||
position_number_[inst] = position;
|
item->instruction = const_cast<HloInstruction*>(inst);
|
||||||
first_at_position_[position] = inst;
|
item->position = position;
|
||||||
position++;
|
position++;
|
||||||
|
|
||||||
|
item_map_[inst] = item;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the list of instructions.
|
~InstructionList() {
|
||||||
const std::list<HloInstruction*>& instructions() const {
|
for (Item* item = first_; item != nullptr;) {
|
||||||
return instructions_;
|
Item* next = item->next;
|
||||||
|
delete item;
|
||||||
|
item = next;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert instruction 'to_insert' immediately before instruction 'before' in
|
size_t size() const { return item_map_.size(); }
|
||||||
// the list.
|
|
||||||
void InsertBefore(HloInstruction* to_insert, HloInstruction* before) {
|
// For ordered iteration over items.
|
||||||
VLOG(3) << "InsertBefore: " << to_insert->name() << " before "
|
// for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
|
||||||
<< before->name();
|
Item* first() const { return first_; }
|
||||||
auto it = instruction_iterators_.find(before);
|
Item* next(Item* item) const { return item->next; }
|
||||||
CHECK(it != instruction_iterators_.end());
|
|
||||||
instruction_iterators_.insert(
|
// Creates an Item for the given instruction, but doesn't add it to the list.
|
||||||
{to_insert, instructions_.insert(it->second, to_insert)});
|
// (Use InsertBeforeInstructions to add the Item to the list.)
|
||||||
// Assign the same position number to the newly added instruction as
|
Item* CreateItem(HloInstruction* inst) {
|
||||||
// 'before'. This guarantees monotonicity of the position numbers, but not
|
Item* item = new Item;
|
||||||
// uniqueness.
|
item->instruction = inst;
|
||||||
int64 pos = position_number_.at(before);
|
CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice";
|
||||||
position_number_[to_insert] = pos;
|
return item;
|
||||||
if (first_at_position_.at(pos) == before) {
|
}
|
||||||
first_at_position_[pos] = to_insert;
|
|
||||||
}
|
// Return the Item corresponding to inst.
|
||||||
|
Item* GetItem(const HloInstruction* inst) const {
|
||||||
|
auto iter = item_map_.find(inst);
|
||||||
|
CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
|
||||||
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert instruction 'to_insert' immediately before the earliest instruction
|
// Insert instruction 'to_insert' immediately before the earliest instruction
|
||||||
// in 'before_instructions'.
|
// in 'before_instructions'.
|
||||||
|
//
|
||||||
|
// Each instruction gets a non-decreasing ordinal number. We use this to let
|
||||||
|
// InsertBeforeInstructions quickly insert an instruction before the earliest
|
||||||
|
// instruction in a set of instructions. If position_number_[a] <
|
||||||
|
// position_number_[b] then 'a' comes before 'b' in the list. If the position
|
||||||
|
// numbers are the same then nothing can be said about their order without
|
||||||
|
// examining the list.
|
||||||
|
//
|
||||||
|
// On object construction this ordinal is precisely the instruction's index
|
||||||
|
// in the list. Later, instructions inserted via InsertBefore receive
|
||||||
|
// duplicate values. However, monotonicity is preserved.
|
||||||
void InsertBeforeInstructions(
|
void InsertBeforeInstructions(
|
||||||
HloInstruction* to_insert,
|
Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
|
||||||
tensorflow::gtl::ArraySlice<HloInstruction*> before_instructions) {
|
VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
|
||||||
VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {"
|
<< " before {"
|
||||||
<< tensorflow::str_util::Join(
|
<< tensorflow::str_util::Join(before_instructions, ", ",
|
||||||
before_instructions, ", ",
|
[](string* out, Item* item) {
|
||||||
[](string* out, HloInstruction* inst) {
|
tensorflow::strings::StrAppend(
|
||||||
tensorflow::strings::StrAppend(out, inst->name());
|
out, item->instruction->name());
|
||||||
})
|
})
|
||||||
<< "}";
|
<< "}";
|
||||||
|
|
||||||
// Find the minimal position number of any instruction in
|
// Find the minimal position number of any instruction in
|
||||||
// 'before_instructions'.
|
// 'before_instructions'.
|
||||||
CHECK(!before_instructions.empty());
|
CHECK(!before_instructions.empty());
|
||||||
int64 min_position_number = std::numeric_limits<int64>::max();
|
Item* min_position_item = nullptr;
|
||||||
for (const HloInstruction* instruction : before_instructions) {
|
for (Item* item : before_instructions) {
|
||||||
min_position_number =
|
if (min_position_item == nullptr ||
|
||||||
std::min(min_position_number, position_number_.at(instruction));
|
item->position < min_position_item->position) {
|
||||||
|
min_position_item = item;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Because more than one instruction in 'before_instructions' may have a
|
// Because more than one instruction in 'before_instructions' may have a
|
||||||
// position number of 'min_position_number', find the first such instruction
|
// position number of 'min_position_number', find the first such instruction
|
||||||
// with position number 'min_position_number'.
|
// with position number 'min_position_number'.
|
||||||
for (auto it = instruction_iterators_.at(
|
|
||||||
first_at_position_.at(min_position_number));
|
// First find first instruction with the min position.
|
||||||
it != instructions_.end() &&
|
while (min_position_item->prev != nullptr &&
|
||||||
position_number_.at(*it) == min_position_number;
|
min_position_item->position == min_position_item->prev->position) {
|
||||||
++it) {
|
min_position_item = min_position_item->prev;
|
||||||
if (std::find(before_instructions.begin(), before_instructions.end(),
|
|
||||||
*it) != before_instructions.end()) {
|
|
||||||
return InsertBefore(to_insert, *it);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
LOG(FATAL) << "Expected to find instruction in before_instructions with "
|
|
||||||
"position number "
|
// Now scan forwards until we find one of the before_instructions.
|
||||||
<< min_position_number;
|
while (std::find(before_instructions.begin(), before_instructions.end(),
|
||||||
|
min_position_item) == before_instructions.end()) {
|
||||||
|
min_position_item = min_position_item->next;
|
||||||
|
}
|
||||||
|
return InsertBefore(to_insert, min_position_item);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Blacklist(const HloInstruction* inst) {
|
||||||
|
GetItem(inst)->blacklisted = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// List of instructions.
|
// Insert instruction 'item' immediately before 'before' in the list.
|
||||||
std::list<HloInstruction*> instructions_;
|
void InsertBefore(Item* item, Item* before) {
|
||||||
|
VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
|
||||||
|
<< before->instruction->name();
|
||||||
|
// Insert new item into linked list.
|
||||||
|
item->prev = before->prev;
|
||||||
|
item->next = before;
|
||||||
|
before->prev = item;
|
||||||
|
if (item->prev != nullptr) {
|
||||||
|
item->prev->next = item;
|
||||||
|
} else {
|
||||||
|
first_ = item;
|
||||||
|
}
|
||||||
|
|
||||||
// Iterators for each instruction in the list.
|
// Assign the same position number to the newly added instruction as
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*,
|
// 'before'. This guarantees monotonicity of the position numbers, but not
|
||||||
std::list<HloInstruction*>::iterator>
|
// uniqueness.
|
||||||
instruction_iterators_;
|
item->position = before->position;
|
||||||
|
}
|
||||||
|
|
||||||
// A number assigned to each instruction which increases monotonically through
|
Item* first_;
|
||||||
// 'instructions_'. Used to facilitate fast insertion of an instruction before
|
|
||||||
// the earliest instruction in a set of instructions
|
|
||||||
// (InsertBeforeInstructions) by enabling fast-ish ordering queries between
|
|
||||||
// instructions. If position_number_[a] < position_number_[b] then 'a' comes
|
|
||||||
// before 'b' in the list. If the position numbers are the same then nothing
|
|
||||||
// can be said about their order without examining the list.
|
|
||||||
//
|
|
||||||
// On object construction this value is precisely the instruction's ordinal
|
|
||||||
// position in the list. Instructions inserted via InsertBefore receive
|
|
||||||
// duplicate values. However, monotonicity is preserved.
|
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64> position_number_;
|
|
||||||
|
|
||||||
// The first instruction in the list assigned a particular position number.
|
// Item for each instruction.
|
||||||
tensorflow::gtl::FlatMap<int64, const HloInstruction*> first_at_position_;
|
tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Return the HloInstructions which use the given LogicalBuffer. Sets
|
// Return the items which use the given LogicalBuffer. Sets
|
||||||
// has_indirect_users to whether any of the uses is indirect. A use is indirect
|
// has_indirect_users to whether any of the uses is indirect. A use is indirect
|
||||||
// if the instruction defining logical_buffer is not an operand of the use. This
|
// if the instruction defining logical_buffer is not an operand of the use. This
|
||||||
// can happen via buffer aliasing (eg, tuples).
|
// can happen via buffer aliasing (eg, tuples).
|
||||||
std::vector<const HloInstruction*> GetUsers(
|
ItemList GetUsers(const InstructionList& instruction_list,
|
||||||
const LogicalBuffer* logical_buffer,
|
const LogicalBuffer* logical_buffer,
|
||||||
const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) {
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
std::vector<const HloInstruction*> users;
|
bool* has_indirect_users) {
|
||||||
|
ItemList users;
|
||||||
// To identify uses iterate through all HloInstruction users of the
|
// To identify uses iterate through all HloInstruction users of the
|
||||||
// BufferAliases of the logical buffer.
|
// BufferAliases of the logical buffer.
|
||||||
*has_indirect_users = false;
|
*has_indirect_users = false;
|
||||||
@ -219,8 +301,9 @@ std::vector<const HloInstruction*> GetUsers(
|
|||||||
}
|
}
|
||||||
// A buffer may be used by the instruction via more than one alias. For
|
// A buffer may be used by the instruction via more than one alias. For
|
||||||
// example, a buffer which appears in more than one element of a tuple.
|
// example, a buffer which appears in more than one element of a tuple.
|
||||||
if (std::find(users.begin(), users.end(), user) == users.end()) {
|
Item* user_item = instruction_list.GetItem(user);
|
||||||
users.push_back(user);
|
if (std::find(users.begin(), users.end(), user_item) == users.end()) {
|
||||||
|
users.push_back(user_item);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -244,7 +327,7 @@ class MemoryUsageTracker {
|
|||||||
// EndInstruction) to accurately model memory usage. At BeginInstruction the
|
// EndInstruction) to accurately model memory usage. At BeginInstruction the
|
||||||
// memory for the output value(s) of the current instruction is allocated. At
|
// memory for the output value(s) of the current instruction is allocated. At
|
||||||
// EndInstruction memory for dead operand(s) is freed.
|
// EndInstruction memory for dead operand(s) is freed.
|
||||||
Status BeginInstruction(const HloInstruction* instruction);
|
Status BeginInstruction(Item* item);
|
||||||
|
|
||||||
// Finishes the placement of the current instruction. This frees any dead
|
// Finishes the placement of the current instruction. This frees any dead
|
||||||
// operands or dead result of the instruction. This must be called after
|
// operands or dead result of the instruction. This must be called after
|
||||||
@ -253,40 +336,31 @@ class MemoryUsageTracker {
|
|||||||
|
|
||||||
// Returns the number of bytes that the current memory usage will be reduced
|
// Returns the number of bytes that the current memory usage will be reduced
|
||||||
// if the given instruction is rematerialized.
|
// if the given instruction is rematerialized.
|
||||||
int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const;
|
int64 MemoryReducedIfRematerialized(Item* item) const;
|
||||||
|
|
||||||
// Adjusts memory usage to account for the rematerialization of
|
// Adjusts memory usage to account for the rematerialization of
|
||||||
// original_instruction for all remaining unplaced uses. The rematerialization
|
// original_item for all remaining unplaced uses. The rematerialization
|
||||||
// is remat_instruction. This method should be called after the HLO graph has
|
// is remat_item. This method should be called after the HLO graph has
|
||||||
// been transformed (rematerialization instruction created and connected to
|
// been transformed (rematerialization instruction created and connected to
|
||||||
// uses).
|
// uses).
|
||||||
Status AddRematerializedInstruction(HloInstruction* original_instruction,
|
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
|
||||||
HloInstruction* remat_instruction);
|
|
||||||
|
|
||||||
// Returns whether the given instruction has been placed (BeginInstruction
|
// Returns whether the given instruction has been placed (BeginInstruction
|
||||||
// has been called with 'instruction' as the argument).
|
// has been called with 'instruction' as the argument).
|
||||||
bool IsPlaced(const HloInstruction* instruction) const {
|
bool IsPlaced(const HloInstruction* instruction) const {
|
||||||
return ContainsKey(placed_instructions_, instruction);
|
return instruction_list_.GetItem(instruction)->placed;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the current memory usage. This is the sum of sizes of all live
|
// Returns the current memory usage. This is the sum of sizes of all live
|
||||||
// values.
|
// values.
|
||||||
int64 memory_usage() const { return memory_usage_; }
|
int64 memory_usage() const { return memory_usage_; }
|
||||||
|
|
||||||
// Returns the current instruction being placed.
|
|
||||||
const HloInstruction* in_progress_instruction() const {
|
|
||||||
return in_progress_instruction_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check invariants of the data structure. This is expensive to call.
|
// Check invariants of the data structure. This is expensive to call.
|
||||||
bool Check() const;
|
bool Check() const;
|
||||||
|
|
||||||
string ToString() const;
|
string ToString() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Type holding a unique identifier for each Buffer object.
|
|
||||||
using BufferId = int64;
|
|
||||||
|
|
||||||
// A Buffer represents a single LogicalBuffer in the computation including
|
// A Buffer represents a single LogicalBuffer in the computation including
|
||||||
// various metadata useful for tracking liveness of the value. A LogicalBuffer
|
// various metadata useful for tracking liveness of the value. A LogicalBuffer
|
||||||
// is not used directly because the HLO graph is transformed and
|
// is not used directly because the HLO graph is transformed and
|
||||||
@ -298,7 +372,7 @@ class MemoryUsageTracker {
|
|||||||
const BufferId id;
|
const BufferId id;
|
||||||
|
|
||||||
// The instruction which defines this buffer.
|
// The instruction which defines this buffer.
|
||||||
const HloInstruction* defining_instruction;
|
Item* defining_instruction;
|
||||||
|
|
||||||
// The materialized size of the buffer in bytes.
|
// The materialized size of the buffer in bytes.
|
||||||
const int64 size;
|
const int64 size;
|
||||||
@ -312,16 +386,17 @@ class MemoryUsageTracker {
|
|||||||
bool has_indirect_uses;
|
bool has_indirect_uses;
|
||||||
|
|
||||||
// The instructions which use this buffer.
|
// The instructions which use this buffer.
|
||||||
std::vector<const HloInstruction*> users;
|
ItemList users;
|
||||||
|
|
||||||
// The number of users (HloInstructions) of this buffer which have not yet
|
// The number of users (HloInstructions) of this buffer which have not yet
|
||||||
// been placed in the sequence.
|
// been placed in the sequence.
|
||||||
int64 unfinished_user_count;
|
int64 unfinished_user_count;
|
||||||
|
|
||||||
string ToString() const {
|
string ToString() const {
|
||||||
return tensorflow::strings::StrCat("Buffer ", id, " (defined by ",
|
return tensorflow::strings::StrCat(
|
||||||
defining_instruction->name(),
|
"Buffer ", id, " (defined by ",
|
||||||
", size ", size, " bytes)");
|
defining_instruction->instruction->name(), ", size ", size,
|
||||||
|
" bytes)");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -333,25 +408,24 @@ class MemoryUsageTracker {
|
|||||||
const HloRematerialization::ShapeSizeFunction& size_function,
|
const HloRematerialization::ShapeSizeFunction& size_function,
|
||||||
bool live_out) {
|
bool live_out) {
|
||||||
bool has_indirect_uses = false;
|
bool has_indirect_uses = false;
|
||||||
std::vector<const HloInstruction*> users =
|
ItemList users = GetUsers(instruction_list_, logical_buffer,
|
||||||
GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses);
|
points_to_analysis, &has_indirect_uses);
|
||||||
return NewBuffer(logical_buffer->instruction(),
|
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
|
||||||
size_function(logical_buffer->shape()), std::move(users),
|
size_function(logical_buffer->shape()), std::move(users),
|
||||||
live_out, has_indirect_uses);
|
live_out, has_indirect_uses);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new buffer representing a rematerialization of given buffer for
|
// Create a new buffer representing a rematerialization of given buffer for
|
||||||
// the given uses.
|
// the given uses.
|
||||||
Buffer& RematerializeBuffer(
|
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
|
||||||
const Buffer& original_buffer, const HloInstruction* remat_instruction,
|
ItemList&& rematerialized_uses) {
|
||||||
std::vector<const HloInstruction*>&& rematerialized_uses) {
|
CHECK(original_buffer.defining_instruction->placed);
|
||||||
CHECK(IsPlaced(original_buffer.defining_instruction));
|
|
||||||
CHECK(!original_buffer.has_indirect_uses);
|
CHECK(!original_buffer.has_indirect_uses);
|
||||||
CHECK(!original_buffer.live_out);
|
CHECK(!original_buffer.live_out);
|
||||||
for (const HloInstruction* use : rematerialized_uses) {
|
for (Item* use : rematerialized_uses) {
|
||||||
CHECK(!IsPlaced(use));
|
CHECK(!use->placed);
|
||||||
}
|
}
|
||||||
return NewBuffer(remat_instruction, original_buffer.size,
|
return NewBuffer(remat_item, original_buffer.size,
|
||||||
std::move(rematerialized_uses), /*live_out=*/false,
|
std::move(rematerialized_uses), /*live_out=*/false,
|
||||||
/*has_indirect_uses=*/false);
|
/*has_indirect_uses=*/false);
|
||||||
}
|
}
|
||||||
@ -362,7 +436,7 @@ class MemoryUsageTracker {
|
|||||||
// different computation.
|
// different computation.
|
||||||
int64 AllocatedSize(BufferId buffer_id) const {
|
int64 AllocatedSize(BufferId buffer_id) const {
|
||||||
const Buffer& buffer = buffers_.at(buffer_id);
|
const Buffer& buffer = buffers_.at(buffer_id);
|
||||||
HloOpcode def_opcode = buffer.defining_instruction->opcode();
|
HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
|
||||||
if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
|
if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
|
||||||
return 0;
|
return 0;
|
||||||
} else {
|
} else {
|
||||||
@ -372,18 +446,17 @@ class MemoryUsageTracker {
|
|||||||
|
|
||||||
// Returns true if BeginInstruction and EndInstruction has been called for the
|
// Returns true if BeginInstruction and EndInstruction has been called for the
|
||||||
// given instruction.
|
// given instruction.
|
||||||
bool IsFinished(const HloInstruction* instruction) const {
|
bool IsFinished(Item* item) const {
|
||||||
return IsPlaced(instruction) && instruction != in_progress_instruction_;
|
return item->placed && item != in_progress_item_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether the given buffer is being used by the in-progress
|
// Returns whether the given buffer is being used by the in-progress
|
||||||
// instruction.
|
// instruction.
|
||||||
bool IsInUse(BufferId buffer_id) const {
|
bool IsInUse(BufferId buffer_id) const {
|
||||||
if (in_progress_instruction_ == nullptr) {
|
if (in_progress_item_ == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const std::vector<BufferId>& in_progress_uses =
|
const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
|
||||||
buffers_used_by_instruction_.at(in_progress_instruction_);
|
|
||||||
return std::find(in_progress_uses.begin(), in_progress_uses.end(),
|
return std::find(in_progress_uses.begin(), in_progress_uses.end(),
|
||||||
buffer_id) != in_progress_uses.end();
|
buffer_id) != in_progress_uses.end();
|
||||||
}
|
}
|
||||||
@ -392,14 +465,13 @@ class MemoryUsageTracker {
|
|||||||
// point.
|
// point.
|
||||||
bool IsCurrentlyLive(BufferId buffer_id) const {
|
bool IsCurrentlyLive(BufferId buffer_id) const {
|
||||||
const Buffer& buffer = buffers_[buffer_id];
|
const Buffer& buffer = buffers_[buffer_id];
|
||||||
return (IsPlaced(buffer.defining_instruction) &&
|
return (buffer.defining_instruction->placed &&
|
||||||
buffer.unfinished_user_count > 0);
|
buffer.unfinished_user_count > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new buffer, add it to buffers_, and return a reference.
|
// Create a new buffer, add it to buffers_, and return a reference.
|
||||||
Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size,
|
Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
|
||||||
std::vector<const HloInstruction*>&& users, bool live_out,
|
bool live_out, bool has_indirect_uses) {
|
||||||
bool has_indirect_uses) {
|
|
||||||
int buffer_id = buffers_.size();
|
int buffer_id = buffers_.size();
|
||||||
buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
|
buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
|
||||||
has_indirect_uses, users,
|
has_indirect_uses, users,
|
||||||
@ -419,19 +491,7 @@ class MemoryUsageTracker {
|
|||||||
|
|
||||||
// The instruction currently being placed. This value is non-null only
|
// The instruction currently being placed. This value is non-null only
|
||||||
// between the calling of BeginInstruction and EndInstruction.
|
// between the calling of BeginInstruction and EndInstruction.
|
||||||
const HloInstruction* in_progress_instruction_ = nullptr;
|
Item* in_progress_item_ = nullptr;
|
||||||
|
|
||||||
// The buffers defined by each instruction.
|
|
||||||
std::unordered_map<const HloInstruction*, std::vector<BufferId>>
|
|
||||||
buffers_defined_by_instruction_;
|
|
||||||
|
|
||||||
// The buffers used by each instruction.
|
|
||||||
std::unordered_map<const HloInstruction*, std::vector<BufferId>>
|
|
||||||
buffers_used_by_instruction_;
|
|
||||||
|
|
||||||
// The set of instructions which have been placed. That is, BeginInstruction
|
|
||||||
// has been called with the instruction as an argument.
|
|
||||||
tensorflow::gtl::FlatSet<const HloInstruction*> placed_instructions_;
|
|
||||||
|
|
||||||
// All buffers in the computation.
|
// All buffers in the computation.
|
||||||
std::vector<Buffer> buffers_;
|
std::vector<Buffer> buffers_;
|
||||||
@ -443,22 +503,15 @@ MemoryUsageTracker::MemoryUsageTracker(
|
|||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const InstructionList& instruction_list)
|
const InstructionList& instruction_list)
|
||||||
: computation_(computation), instruction_list_(instruction_list) {
|
: computation_(computation), instruction_list_(instruction_list) {
|
||||||
// Iterate through all LogicalBuffers in the computation and gather the
|
|
||||||
// instructions which define them in buffers_defined_by_instruction_ and the
|
|
||||||
// instructions which use them in buffers_used_by_instruction_.
|
|
||||||
for (auto& instruction : computation_->instructions()) {
|
|
||||||
// Initialize empty vectors for defs and uses of each instruction.
|
|
||||||
buffers_used_by_instruction_[instruction.get()];
|
|
||||||
buffers_defined_by_instruction_[instruction.get()];
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set =
|
tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set =
|
||||||
points_to_analysis.GetPointsToSet(computation_->root_instruction())
|
points_to_analysis.GetPointsToSet(computation_->root_instruction())
|
||||||
.CreateFlattenedSet();
|
.CreateFlattenedSet();
|
||||||
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
|
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
|
||||||
logical_buffer_to_buffer_id;
|
logical_buffer_to_buffer_id;
|
||||||
|
|
||||||
for (const HloInstruction* instruction : instruction_list_.instructions()) {
|
for (auto* item = instruction_list_.first(); item != nullptr;
|
||||||
|
item = instruction_list_.next(item)) {
|
||||||
|
const HloInstruction* const instruction = item->instruction;
|
||||||
for (const LogicalBuffer* logical_buffer :
|
for (const LogicalBuffer* logical_buffer :
|
||||||
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
|
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
|
||||||
Buffer* buffer;
|
Buffer* buffer;
|
||||||
@ -481,22 +534,22 @@ MemoryUsageTracker::MemoryUsageTracker(
|
|||||||
|
|
||||||
// Add users of while to Buffer users.
|
// Add users of while to Buffer users.
|
||||||
bool unused;
|
bool unused;
|
||||||
for (const HloInstruction* user :
|
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
|
||||||
GetUsers(logical_buffer, points_to_analysis, &unused)) {
|
points_to_analysis, &unused)) {
|
||||||
if (std::find(buffer->users.begin(), buffer->users.end(), user) ==
|
if (std::find(buffer->users.begin(), buffer->users.end(),
|
||||||
buffer->users.end()) {
|
user_item) == buffer->users.end()) {
|
||||||
buffer->users.push_back(user);
|
buffer->users.push_back(user_item);
|
||||||
buffer->unfinished_user_count++;
|
buffer->unfinished_user_count++;
|
||||||
buffers_used_by_instruction_.at(user).push_back(buffer->id);
|
user_item->buffers_used.push_back(buffer->id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
buffer = &CreateBufferFromLogicalBuffer(
|
buffer = &CreateBufferFromLogicalBuffer(
|
||||||
logical_buffer, points_to_analysis, size_function,
|
logical_buffer, points_to_analysis, size_function,
|
||||||
ContainsKey(live_out_set, logical_buffer));
|
ContainsKey(live_out_set, logical_buffer));
|
||||||
buffers_defined_by_instruction_.at(instruction).push_back(buffer->id);
|
item->buffers_defined.push_back(buffer->id);
|
||||||
for (const HloInstruction* user : buffer->users) {
|
for (Item* user : buffer->users) {
|
||||||
buffers_used_by_instruction_.at(user).push_back(buffer->id);
|
user->buffers_used.push_back(buffer->id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -507,15 +560,16 @@ MemoryUsageTracker::MemoryUsageTracker(
|
|||||||
DCHECK(Check());
|
DCHECK(Check());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) {
|
Status MemoryUsageTracker::BeginInstruction(Item* item) {
|
||||||
|
const HloInstruction* instruction = item->instruction;
|
||||||
VLOG(3) << "BeginInstruction " << instruction->name();
|
VLOG(3) << "BeginInstruction " << instruction->name();
|
||||||
TF_RET_CHECK(in_progress_instruction_ == nullptr);
|
TF_RET_CHECK(in_progress_item_ == nullptr);
|
||||||
in_progress_instruction_ = instruction;
|
in_progress_item_ = item;
|
||||||
|
|
||||||
placed_instructions_.insert(in_progress_instruction_);
|
item->placed = true;
|
||||||
|
|
||||||
// All buffers defined by this instruction need memory.
|
// All buffers defined by this instruction need memory.
|
||||||
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
|
for (BufferId buffer_id : item->buffers_defined) {
|
||||||
VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString()
|
VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString()
|
||||||
<< " is now live.";
|
<< " is now live.";
|
||||||
memory_usage_ += AllocatedSize(buffer_id);
|
memory_usage_ += AllocatedSize(buffer_id);
|
||||||
@ -532,11 +586,10 @@ Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status MemoryUsageTracker::EndInstruction() {
|
Status MemoryUsageTracker::EndInstruction() {
|
||||||
TF_RET_CHECK(in_progress_instruction_ != nullptr);
|
TF_RET_CHECK(in_progress_item_ != nullptr);
|
||||||
VLOG(3) << "EndInstruction " << in_progress_instruction_->name();
|
VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();
|
||||||
|
|
||||||
for (BufferId buffer_id :
|
for (BufferId buffer_id : in_progress_item_->buffers_used) {
|
||||||
buffers_used_by_instruction_.at(in_progress_instruction_)) {
|
|
||||||
Buffer& buffer = buffers_.at(buffer_id);
|
Buffer& buffer = buffers_.at(buffer_id);
|
||||||
buffer.unfinished_user_count--;
|
buffer.unfinished_user_count--;
|
||||||
CHECK_GE(buffer.unfinished_user_count, 0)
|
CHECK_GE(buffer.unfinished_user_count, 0)
|
||||||
@ -551,8 +604,7 @@ Status MemoryUsageTracker::EndInstruction() {
|
|||||||
|
|
||||||
// If any buffer defined by this instruction has no uses, then memory can be
|
// If any buffer defined by this instruction has no uses, then memory can be
|
||||||
// reclaimed immediately.
|
// reclaimed immediately.
|
||||||
for (BufferId buffer_id :
|
for (BufferId buffer_id : in_progress_item_->buffers_defined) {
|
||||||
buffers_defined_by_instruction_.at(in_progress_instruction_)) {
|
|
||||||
const Buffer& buffer = buffers_.at(buffer_id);
|
const Buffer& buffer = buffers_.at(buffer_id);
|
||||||
if (buffer.unfinished_user_count == 0) {
|
if (buffer.unfinished_user_count == 0) {
|
||||||
VLOG(3) << " " << buffer.ToString() << " is immediately dead.";
|
VLOG(3) << " " << buffer.ToString() << " is immediately dead.";
|
||||||
@ -561,7 +613,7 @@ Status MemoryUsageTracker::EndInstruction() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
in_progress_instruction_ = nullptr;
|
in_progress_item_ = nullptr;
|
||||||
|
|
||||||
VLOG(3) << " memory usage = " << memory_usage_;
|
VLOG(3) << " memory usage = " << memory_usage_;
|
||||||
VLOG(10) << ToString();
|
VLOG(10) << ToString();
|
||||||
@ -571,10 +623,9 @@ Status MemoryUsageTracker::EndInstruction() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
|
||||||
const HloInstruction* instruction) const {
|
CHECK_NE(in_progress_item_, nullptr);
|
||||||
CHECK_NE(in_progress_instruction_, nullptr);
|
if (!item->placed || item == in_progress_item_) {
|
||||||
if (!IsPlaced(instruction) || instruction == in_progress_instruction_) {
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -589,7 +640,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
|||||||
// be live at this program point, so initially set memory_reduced to the
|
// be live at this program point, so initially set memory_reduced to the
|
||||||
// size of its defined values.
|
// size of its defined values.
|
||||||
int64 memory_reduced = 0;
|
int64 memory_reduced = 0;
|
||||||
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
|
for (BufferId buffer_id : item->buffers_defined) {
|
||||||
// Avoid rematerializing instructions with indirect uses as it is difficult
|
// Avoid rematerializing instructions with indirect uses as it is difficult
|
||||||
// to reason about liveness after rematerializing the instruction.
|
// to reason about liveness after rematerializing the instruction.
|
||||||
// TODO(b/37714814): Consider rematerializing instructions with indirect
|
// TODO(b/37714814): Consider rematerializing instructions with indirect
|
||||||
@ -605,7 +656,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
|||||||
|
|
||||||
// Account for any logical buffers whose live range must be extended across
|
// Account for any logical buffers whose live range must be extended across
|
||||||
// this program point.
|
// this program point.
|
||||||
for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) {
|
for (BufferId buffer_id : item->buffers_used) {
|
||||||
if (!IsCurrentlyLive(buffer_id)) {
|
if (!IsCurrentlyLive(buffer_id)) {
|
||||||
// This logical buffer is used by 'instruction' but is not live at this
|
// This logical buffer is used by 'instruction' but is not live at this
|
||||||
// program point. Rematerializing 'instruction' will extend the buffer's
|
// program point. Rematerializing 'instruction' will extend the buffer's
|
||||||
@ -617,28 +668,23 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
|||||||
return memory_reduced;
|
return memory_reduced;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MemoryUsageTracker::AddRematerializedInstruction(
|
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
|
||||||
HloInstruction* original_instruction, HloInstruction* remat_instruction) {
|
Item* remat_item) {
|
||||||
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
|
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
|
||||||
<< original_instruction->name()
|
<< original_item->instruction->name()
|
||||||
<< ", remat_instruction = " << remat_instruction->name();
|
<< ", remat_instruction = " << remat_item->instruction->name();
|
||||||
|
|
||||||
TF_RET_CHECK(in_progress_instruction_ != nullptr);
|
TF_RET_CHECK(in_progress_item_ != nullptr);
|
||||||
TF_RET_CHECK(IsPlaced(original_instruction));
|
TF_RET_CHECK(original_item->placed);
|
||||||
TF_RET_CHECK(!IsPlaced(remat_instruction));
|
TF_RET_CHECK(!remat_item->placed);
|
||||||
CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction));
|
|
||||||
CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction));
|
|
||||||
|
|
||||||
// Construct the list of buffers used and defined by the rematerialization.
|
// Construct the list of buffers used and defined by the rematerialization.
|
||||||
buffers_defined_by_instruction_[remat_instruction];
|
remat_item->buffers_used = original_item->buffers_used;
|
||||||
buffers_used_by_instruction_[remat_instruction] =
|
|
||||||
buffers_used_by_instruction_.at(original_instruction);
|
|
||||||
|
|
||||||
// Account for the additional buffer uses created by the new rematerialization
|
// Account for the additional buffer uses created by the new rematerialization
|
||||||
// instruction. Update memory usage if the rematerialization makes a dead
|
// instruction. Update memory usage if the rematerialization makes a dead
|
||||||
// buffer live again.
|
// buffer live again.
|
||||||
for (BufferId buffer_id :
|
for (BufferId buffer_id : original_item->buffers_used) {
|
||||||
buffers_used_by_instruction_.at(original_instruction)) {
|
|
||||||
Buffer& buffer = buffers_.at(buffer_id);
|
Buffer& buffer = buffers_.at(buffer_id);
|
||||||
if (buffer.unfinished_user_count == 0) {
|
if (buffer.unfinished_user_count == 0) {
|
||||||
// Buffer used by this instruction was dead, now is alive.
|
// Buffer used by this instruction was dead, now is alive.
|
||||||
@ -646,20 +692,19 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
|
|||||||
}
|
}
|
||||||
|
|
||||||
buffer.unfinished_user_count++;
|
buffer.unfinished_user_count++;
|
||||||
buffer.users.push_back(remat_instruction);
|
buffer.users.push_back(remat_item);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new set of Buffers defined by the new rematerialization
|
// Create a new set of Buffers defined by the new rematerialization
|
||||||
// instruction. Update the internal data structures and memory use to account
|
// instruction. Update the internal data structures and memory use to account
|
||||||
// for them.
|
// for them.
|
||||||
for (BufferId old_buffer_id :
|
for (BufferId old_buffer_id : original_item->buffers_defined) {
|
||||||
buffers_defined_by_instruction_.at(original_instruction)) {
|
|
||||||
Buffer& old_buffer = buffers_.at(old_buffer_id);
|
Buffer& old_buffer = buffers_.at(old_buffer_id);
|
||||||
|
|
||||||
std::vector<const HloInstruction*> placed_users;
|
ItemList placed_users;
|
||||||
std::vector<const HloInstruction*> unplaced_users;
|
ItemList unplaced_users;
|
||||||
for (const HloInstruction* user : old_buffer.users) {
|
for (Item* user : old_buffer.users) {
|
||||||
if (IsPlaced(user)) {
|
if (user->placed) {
|
||||||
CHECK(IsFinished(user));
|
CHECK(IsFinished(user));
|
||||||
placed_users.push_back(user);
|
placed_users.push_back(user);
|
||||||
} else {
|
} else {
|
||||||
@ -672,14 +717,12 @@ Status MemoryUsageTracker::AddRematerializedInstruction(
|
|||||||
// Buffer is now dead.
|
// Buffer is now dead.
|
||||||
memory_usage_ -= AllocatedSize(old_buffer.id);
|
memory_usage_ -= AllocatedSize(old_buffer.id);
|
||||||
|
|
||||||
Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction,
|
Buffer& new_buffer =
|
||||||
std::move(unplaced_users));
|
RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
|
||||||
|
|
||||||
buffers_defined_by_instruction_.at(remat_instruction)
|
remat_item->buffers_defined.push_back(new_buffer.id);
|
||||||
.push_back(new_buffer.id);
|
for (Item* user : new_buffer.users) {
|
||||||
for (const HloInstruction* user : new_buffer.users) {
|
BufferIdList& buffers_used = user->buffers_used;
|
||||||
std::vector<BufferId>& buffers_used =
|
|
||||||
buffers_used_by_instruction_.at(user);
|
|
||||||
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
|
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
|
||||||
new_buffer.id);
|
new_buffer.id);
|
||||||
}
|
}
|
||||||
@ -699,13 +742,14 @@ string MemoryUsageTracker::ToString() const {
|
|||||||
tensorflow::strings::StrAppend(
|
tensorflow::strings::StrAppend(
|
||||||
&output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
|
&output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
|
||||||
memory_usage(), " bytes)");
|
memory_usage(), " bytes)");
|
||||||
for (const HloInstruction* instruction : instruction_list_.instructions()) {
|
for (auto* item = instruction_list_.first(); item != nullptr;
|
||||||
string inprogress =
|
item = instruction_list_.next(item)) {
|
||||||
instruction == in_progress_instruction_ ? " in-progress" : "";
|
const HloInstruction* instruction = item->instruction;
|
||||||
string placed = IsPlaced(instruction) ? " placed" : "";
|
string inprogress = item == in_progress_item_ ? " in-progress" : "";
|
||||||
|
string placed = item->placed ? " placed" : "";
|
||||||
tensorflow::strings::StrAppend(&output, " ", instruction->name(),
|
tensorflow::strings::StrAppend(&output, " ", instruction->name(),
|
||||||
inprogress, placed, "\n Defines:\n");
|
inprogress, placed, "\n Defines:\n");
|
||||||
for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) {
|
for (BufferId buffer_id : item->buffers_defined) {
|
||||||
const Buffer& buffer = buffers_[buffer_id];
|
const Buffer& buffer = buffers_[buffer_id];
|
||||||
string live = IsCurrentlyLive(buffer_id) ? " live" : "";
|
string live = IsCurrentlyLive(buffer_id) ? " live" : "";
|
||||||
tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live,
|
tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live,
|
||||||
@ -713,7 +757,7 @@ string MemoryUsageTracker::ToString() const {
|
|||||||
" unfinished uses\n");
|
" unfinished uses\n");
|
||||||
}
|
}
|
||||||
tensorflow::strings::StrAppend(&output, " Uses:\n");
|
tensorflow::strings::StrAppend(&output, " Uses:\n");
|
||||||
for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) {
|
for (BufferId buffer_id : item->buffers_used) {
|
||||||
tensorflow::strings::StrAppend(&output, " ",
|
tensorflow::strings::StrAppend(&output, " ",
|
||||||
buffers_[buffer_id].ToString(), "\n");
|
buffers_[buffer_id].ToString(), "\n");
|
||||||
}
|
}
|
||||||
@ -722,14 +766,14 @@ string MemoryUsageTracker::ToString() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool MemoryUsageTracker::Check() const {
|
bool MemoryUsageTracker::Check() const {
|
||||||
auto elements_are_unique = [](const std::vector<BufferId>& vec) {
|
auto elements_are_unique = [](const BufferIdList& vec) {
|
||||||
return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
|
return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
|
||||||
};
|
};
|
||||||
|
|
||||||
// Verify buffers_defined_by_instruction_.
|
// Verify buffers_defined per instruction.
|
||||||
for (auto& instruction : computation_->instructions()) {
|
for (auto& instruction : computation_->instructions()) {
|
||||||
const std::vector<BufferId>& defined_buffers =
|
const BufferIdList& defined_buffers =
|
||||||
buffers_defined_by_instruction_.at(instruction.get());
|
instruction_list_.GetItem(instruction.get())->buffers_defined;
|
||||||
CHECK(elements_are_unique(defined_buffers))
|
CHECK(elements_are_unique(defined_buffers))
|
||||||
<< "Instruction " << instruction->name()
|
<< "Instruction " << instruction->name()
|
||||||
<< " does not have unique defined buffers: "
|
<< " does not have unique defined buffers: "
|
||||||
@ -740,7 +784,7 @@ bool MemoryUsageTracker::Check() const {
|
|||||||
});
|
});
|
||||||
|
|
||||||
for (const Buffer& buffer : buffers_) {
|
for (const Buffer& buffer : buffers_) {
|
||||||
if (buffer.defining_instruction == instruction.get()) {
|
if (buffer.defining_instruction->instruction == instruction.get()) {
|
||||||
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
|
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
|
||||||
buffer.id) != defined_buffers.end())
|
buffer.id) != defined_buffers.end())
|
||||||
<< "Instruction " << instruction->name()
|
<< "Instruction " << instruction->name()
|
||||||
@ -749,10 +793,10 @@ bool MemoryUsageTracker::Check() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify buffers_used_by_instruction_.
|
// Verify buffers_used per instruction.
|
||||||
for (auto& instruction : computation_->instructions()) {
|
for (auto& instruction : computation_->instructions()) {
|
||||||
const std::vector<BufferId>& used_buffers =
|
const BufferIdList& used_buffers =
|
||||||
buffers_used_by_instruction_.at(instruction.get());
|
instruction_list_.GetItem(instruction.get())->buffers_used;
|
||||||
CHECK(elements_are_unique(used_buffers))
|
CHECK(elements_are_unique(used_buffers))
|
||||||
<< "Instruction " << instruction->name()
|
<< "Instruction " << instruction->name()
|
||||||
<< " does not have unique used buffers: "
|
<< " does not have unique used buffers: "
|
||||||
@ -764,13 +808,12 @@ bool MemoryUsageTracker::Check() const {
|
|||||||
}
|
}
|
||||||
for (const Buffer& buffer : buffers_) {
|
for (const Buffer& buffer : buffers_) {
|
||||||
int64 unfinished_uses = 0;
|
int64 unfinished_uses = 0;
|
||||||
for (const HloInstruction* user : buffer.users) {
|
for (Item* user : buffer.users) {
|
||||||
const std::vector<BufferId>& used_buffers =
|
const BufferIdList& used_buffers = user->buffers_used;
|
||||||
buffers_used_by_instruction_.at(user);
|
|
||||||
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
|
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
|
||||||
used_buffers.end())
|
used_buffers.end())
|
||||||
<< "Instruction " << user->name() << " used buffers is missing "
|
<< "Instruction " << user->instruction->name()
|
||||||
<< buffer.ToString();
|
<< " used buffers is missing " << buffer.ToString();
|
||||||
if (!IsFinished(user)) {
|
if (!IsFinished(user)) {
|
||||||
unfinished_uses++;
|
unfinished_uses++;
|
||||||
}
|
}
|
||||||
@ -785,8 +828,8 @@ bool MemoryUsageTracker::Check() const {
|
|||||||
// The while instruction reuses its input buffers as output buffers so
|
// The while instruction reuses its input buffers as output buffers so
|
||||||
// don't double count its buffers if it is currently executing.
|
// don't double count its buffers if it is currently executing.
|
||||||
if (IsCurrentlyLive(buffer.id) &&
|
if (IsCurrentlyLive(buffer.id) &&
|
||||||
!(buffer.defining_instruction == in_progress_instruction_ &&
|
!(buffer.defining_instruction == in_progress_item_ &&
|
||||||
in_progress_instruction_->opcode() == HloOpcode::kWhile)) {
|
in_progress_item_->instruction->opcode() == HloOpcode::kWhile)) {
|
||||||
live_size += AllocatedSize(buffer.id);
|
live_size += AllocatedSize(buffer.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -830,26 +873,26 @@ int64 RematerializationCost(const HloInstruction* instruction,
|
|||||||
// candidate which reduce memory use at the program point of the current
|
// candidate which reduce memory use at the program point of the current
|
||||||
// instruction as indicated by memory_tracker. nullptr is returned if no
|
// instruction as indicated by memory_tracker. nullptr is returned if no
|
||||||
// candidate can be found.
|
// candidate can be found.
|
||||||
HloInstruction* PickRematerializationCandidate(
|
Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
|
||||||
const MemoryUsageTracker& memory_tracker,
|
const InstructionList& instruction_list,
|
||||||
const InstructionList& instruction_list,
|
int64 memory_limit_bytes) {
|
||||||
const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist,
|
Item* best_item = nullptr;
|
||||||
int64 memory_limit_bytes) {
|
|
||||||
HloInstruction* best = nullptr;
|
|
||||||
int64 best_cost = 0;
|
int64 best_cost = 0;
|
||||||
|
|
||||||
// TODO(b/35244891): This is currently quadratic in the number of HLO
|
// TODO(b/35244891): This is currently quadratic in the number of HLO
|
||||||
// instructions.
|
// instructions.
|
||||||
for (HloInstruction* candidate : instruction_list.instructions()) {
|
for (auto* item = instruction_list.first(); item != nullptr;
|
||||||
if (!memory_tracker.IsPlaced(candidate)) {
|
item = instruction_list.next(item)) {
|
||||||
// Only iterate up to the currently placed instruction as indicated by
|
if (!item->placed) {
|
||||||
// memory_tracker. We are trying to reduce memory usage at the placed
|
// Only iterate up to the currently placed instruction.
|
||||||
|
// We are trying to reduce memory usage at the placed
|
||||||
// instruction so rematerializing later values is of no benefit.
|
// instruction so rematerializing later values is of no benefit.
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
HloInstruction* candidate = item->instruction;
|
||||||
VLOG(5) << "considering rematerialization candidate " << candidate->name();
|
VLOG(5) << "considering rematerialization candidate " << candidate->name();
|
||||||
|
|
||||||
if (ContainsKey(blacklist, candidate)) {
|
if (item->blacklisted) {
|
||||||
// Skip instructions on the blacklist to avoid infinite loops of
|
// Skip instructions on the blacklist to avoid infinite loops of
|
||||||
// rematerializing the same instruction(s) repeatedly.
|
// rematerializing the same instruction(s) repeatedly.
|
||||||
VLOG(5) << "candidate " << candidate->name()
|
VLOG(5) << "candidate " << candidate->name()
|
||||||
@ -864,7 +907,7 @@ HloInstruction* PickRematerializationCandidate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const int64 memory_reduced =
|
const int64 memory_reduced =
|
||||||
memory_tracker.MemoryReducedIfRematerialized(candidate);
|
memory_tracker.MemoryReducedIfRematerialized(item);
|
||||||
|
|
||||||
if (memory_reduced <= 0) {
|
if (memory_reduced <= 0) {
|
||||||
VLOG(5) << "candidate " << candidate->name()
|
VLOG(5) << "candidate " << candidate->name()
|
||||||
@ -878,13 +921,13 @@ HloInstruction* PickRematerializationCandidate(
|
|||||||
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
|
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
|
||||||
<< memory_reduced << ", cost per byte " << cost;
|
<< memory_reduced << ", cost per byte " << cost;
|
||||||
|
|
||||||
if (best == nullptr || cost < best_cost) {
|
if (best_item == nullptr || cost < best_cost) {
|
||||||
VLOG(5) << "candidate " << candidate->name() << " now best";
|
VLOG(5) << "candidate " << candidate->name() << " now best";
|
||||||
best = candidate;
|
best_item = item;
|
||||||
best_cost = cost;
|
best_cost = cost;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return best;
|
return best_item;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -896,8 +939,10 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
|
|||||||
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
|
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
|
||||||
instruction_list);
|
instruction_list);
|
||||||
int64 peak_memory = tracker.memory_usage();
|
int64 peak_memory = tracker.memory_usage();
|
||||||
for (const HloInstruction* instruction : order) {
|
for (auto* item = instruction_list.first(); item != nullptr;
|
||||||
TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction));
|
item = instruction_list.next(item)) {
|
||||||
|
const HloInstruction* instruction = item->instruction;
|
||||||
|
TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
|
||||||
TF_ASSIGN_OR_RETURN(int64 callee_usage,
|
TF_ASSIGN_OR_RETURN(int64 callee_usage,
|
||||||
CalledComputationsMemoryUsage(instruction));
|
CalledComputationsMemoryUsage(instruction));
|
||||||
peak_memory =
|
peak_memory =
|
||||||
@ -939,11 +984,6 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
*points_to_analysis_, instruction_list);
|
*points_to_analysis_, instruction_list);
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
|
||||||
// To avoid an infinite loop rematerializing the same set of instructions ad
|
|
||||||
// infinitum, keep a blacklist of instructions which should not be
|
|
||||||
// rematerialized.
|
|
||||||
tensorflow::gtl::FlatSet<const HloInstruction*> blacklist;
|
|
||||||
|
|
||||||
// If the rematerialization makes the source instruction dead, then the
|
// If the rematerialization makes the source instruction dead, then the
|
||||||
// rematerialization is added to 'remat_move_instructions' (the
|
// rematerialization is added to 'remat_move_instructions' (the
|
||||||
// rematerialization is essentially a move). If the next rematerialization of
|
// rematerialization is essentially a move). If the next rematerialization of
|
||||||
@ -967,17 +1007,17 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
// (program point) if memory_usage exceeds the specified limit then
|
// (program point) if memory_usage exceeds the specified limit then
|
||||||
// rematerialize HLO instructions until memory_usage is reduced.
|
// rematerialize HLO instructions until memory_usage is reduced.
|
||||||
int64 instruction_index = 0;
|
int64 instruction_index = 0;
|
||||||
for (auto list_it = instruction_list.instructions().begin();
|
for (auto* item = instruction_list.first(); item != nullptr;
|
||||||
list_it != instruction_list.instructions().end(); ++list_it) {
|
item = instruction_list.next(item)) {
|
||||||
HloInstruction* instruction = *list_it;
|
const HloInstruction* instruction = item->instruction;
|
||||||
TF_ASSIGN_OR_RETURN(int64 callee_usage,
|
TF_ASSIGN_OR_RETURN(int64 callee_usage,
|
||||||
CalledComputationsMemoryUsage(instruction));
|
CalledComputationsMemoryUsage(instruction));
|
||||||
TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(instruction));
|
TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));
|
||||||
|
|
||||||
VLOG(2) << "Program point at " << instruction->name()
|
VLOG(2) << "Program point at " << instruction->name()
|
||||||
<< ", memory usage = " << memory_tracker.memory_usage()
|
<< ", memory usage = " << memory_tracker.memory_usage()
|
||||||
<< ", callee usage = " << callee_usage << ", [" << instruction_index
|
<< ", callee usage = " << callee_usage << ", [" << instruction_index
|
||||||
<< "/" << instruction_list.instructions().size() << "]";
|
<< "/" << instruction_list.size() << "]";
|
||||||
instruction_index++;
|
instruction_index++;
|
||||||
|
|
||||||
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
|
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
|
||||||
@ -987,10 +1027,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
callee_usage)
|
callee_usage)
|
||||||
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
|
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
|
||||||
|
|
||||||
HloInstruction* best = PickRematerializationCandidate(
|
Item* best_item = PickRematerializationCandidate(
|
||||||
memory_tracker, instruction_list, blacklist, memory_limit_bytes);
|
memory_tracker, instruction_list, memory_limit_bytes);
|
||||||
|
|
||||||
if (best == nullptr) {
|
if (best_item == nullptr) {
|
||||||
VLOG(3) << "Unable to find rematerialization candidate at program "
|
VLOG(3) << "Unable to find rematerialization candidate at program "
|
||||||
"point "
|
"point "
|
||||||
<< instruction->name() << ". Memory usage = "
|
<< instruction->name() << ". Memory usage = "
|
||||||
@ -999,13 +1039,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloInstruction* best = best_item->instruction;
|
||||||
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
|
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
|
||||||
<< memory_tracker.MemoryReducedIfRematerialized(best) << ")";
|
<< memory_tracker.MemoryReducedIfRematerialized(best_item) << ")";
|
||||||
changed = true;
|
changed = true;
|
||||||
remat_count++;
|
remat_count++;
|
||||||
|
|
||||||
HloInstruction* remat =
|
HloInstruction* remat =
|
||||||
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
|
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
|
||||||
|
Item* remat_item = instruction_list.CreateItem(remat);
|
||||||
|
|
||||||
// Replace each remaining use of 'best' with the rematerialization.
|
// Replace each remaining use of 'best' with the rematerialization.
|
||||||
std::vector<HloInstruction*> best_users_copy = best->users();
|
std::vector<HloInstruction*> best_users_copy = best->users();
|
||||||
@ -1019,22 +1061,28 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
|
|
||||||
// Account for the rematerialization in the memory tracker.
|
// Account for the rematerialization in the memory tracker.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
memory_tracker.AddRematerializedInstruction(best, remat));
|
memory_tracker.AddRematerializedInstruction(best_item, remat_item));
|
||||||
|
|
||||||
// Insert rematerialized instruction right before the earliest unplaced
|
// Insert rematerialized instruction right before the earliest unplaced
|
||||||
// use of the instruction *and* the earliest unplaced last use of any
|
// use of the instruction *and* the earliest unplaced last use of any
|
||||||
// operands of remat. Unplaced uses of the remat's operands are included
|
// operands of remat. Unplaced uses of the remat's operands are included
|
||||||
// because we don't want to extend the live range of remat's operands as
|
// because we don't want to extend the live range of remat's operands as
|
||||||
// this could increase memory usage.
|
// this could increase memory usage.
|
||||||
std::vector<HloInstruction*> place_before = remat->users();
|
ItemList place_before;
|
||||||
|
for (auto user : remat->users()) {
|
||||||
|
place_before.push_back(instruction_list.GetItem(user));
|
||||||
|
}
|
||||||
for (auto* operand : remat->operands()) {
|
for (auto* operand : remat->operands()) {
|
||||||
for (auto* operand_user : operand->users()) {
|
for (auto* operand_user : operand->users()) {
|
||||||
if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) {
|
if (operand_user != remat) {
|
||||||
place_before.push_back(operand_user);
|
Item* operand_user_item = instruction_list.GetItem(operand_user);
|
||||||
|
if (!operand_user_item->placed) {
|
||||||
|
place_before.push_back(operand_user_item);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
instruction_list.InsertBeforeInstructions(remat, place_before);
|
instruction_list.InsertBeforeInstructions(remat_item, place_before);
|
||||||
|
|
||||||
// If the rematerialized instruction is dead then rematerialization is
|
// If the rematerialized instruction is dead then rematerialization is
|
||||||
// essentially a move. Don't delete the instruction now because we don't
|
// essentially a move. Don't delete the instruction now because we don't
|
||||||
@ -1048,7 +1096,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
// instruction it was a copy of. Now 'remat' is a rematerialization
|
// instruction it was a copy of. Now 'remat' is a rematerialization
|
||||||
// of 'best' and kills 'best'. Stop rematerializing this instruction
|
// of 'best' and kills 'best'. Stop rematerializing this instruction
|
||||||
// to avoid an infinite loop.
|
// to avoid an infinite loop.
|
||||||
blacklist.insert(remat);
|
instruction_list.Blacklist(remat);
|
||||||
}
|
}
|
||||||
remat_move_instructions.insert(remat);
|
remat_move_instructions.insert(remat);
|
||||||
} else {
|
} else {
|
||||||
@ -1116,10 +1164,13 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
computation_peak_memory_.at(computation) = peak_memory;
|
computation_peak_memory_.at(computation) = peak_memory;
|
||||||
|
|
||||||
// Update order to include rematerialized instructions.
|
// Update order to include rematerialized instructions.
|
||||||
sequence->at(computation)
|
auto& dst = sequence->at(computation);
|
||||||
.assign(instruction_list.instructions().begin(),
|
dst.clear();
|
||||||
instruction_list.instructions().end());
|
for (auto* item = instruction_list.first(); item != nullptr;
|
||||||
|
item = instruction_list.next(item)) {
|
||||||
|
const HloInstruction* instruction = item->instruction;
|
||||||
|
dst.push_back(instruction);
|
||||||
|
}
|
||||||
rematerialized_computations_.insert(computation);
|
rematerialized_computations_.insert(computation);
|
||||||
|
|
||||||
instructions_rematerialized_ += remat_count;
|
instructions_rematerialized_ += remat_count;
|
||||||
|
Loading…
Reference in New Issue
Block a user