Compressing rematerialization

Adds a new kind of rematerialization that compresses a node's output into a compact form and uncompresses it back at a later program point.

PiperOrigin-RevId: 264734096
Yunxing Dai 2019-08-21 17:52:00 -07:00 committed by TensorFlower Gardener
parent e9a9801734
commit 02f3c946f4
3 changed files with 596 additions and 125 deletions
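For context, a minimal usage sketch (not part of this change) is shown below. It assumes a scheduled HloModule* `module` and a user-provided ShapeSizeFunction named `SizeBytes` already exist; the compact-shape callback here is purely illustrative and returns the shape unchanged, which leaves the pass behaving like recompute-only rematerialization. The test file in this change defines a real callback, ChooseCompactLayoutForShape, that swaps dimensions into a cheaper layout.

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

// Sketch only: SizeBytes is an assumed, user-provided ShapeSizeFunction.
xla::StatusOr<bool> RunCompressingRemat(xla::HloModule* module,
                                        xla::int64 memory_limit_bytes) {
  // Illustrative CompactShapeFunction: returning the shape unchanged means
  // MemoryReducedIfCompressed is never positive, so compression is never picked.
  auto choose_compact_shape =
      [](const xla::Shape& shape) -> xla::StatusOr<xla::Shape> { return shape; };
  // The new fourth constructor argument enables compressing rematerialization.
  xla::HloRematerialization remat(SizeBytes, memory_limit_bytes,
                                  /*sizes=*/nullptr, choose_compact_shape);
  return remat.Run(module);
}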


@ -100,6 +100,17 @@ bool CanBeRematerialized(
using BufferId = int64;
using BufferIdList = absl::InlinedVector<BufferId, 3>;
struct RematStrategy {
enum {
// Recompute the node at a later program point.
kRecompute,
// Change the layout into a compact form and uncompress it back at a later
// program point.
kCompress,
} kind;
Shape compact_shape;
};
// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
@ -117,6 +128,10 @@ struct Item {
// The buffers defined by this instruction.
BufferIdList buffers_defined;
// Output buffers of this instruction. This is used to track outputs of GTE
// instructions (which do not themselves define a buffer).
BufferIdList buffers_output;
// The buffers used by this instruction.
BufferIdList buffers_used;
@ -251,6 +266,34 @@ class InstructionList {
return InsertBefore(to_insert, min_position_item);
}
void InsertAfterInstructions(Item* to_insert,
absl::Span<Item* const> after_instructions) {
VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
<< " after {"
<< absl::StrJoin(after_instructions, ", ",
[](string* out, Item* item) {
absl::StrAppend(out, item->instruction->name());
})
<< "}";
// Find the max position number of any instruction in
// 'after_instructions'.
CHECK(!after_instructions.empty());
Item* max_position_item = nullptr;
for (Item* item : after_instructions) {
if (max_position_item == nullptr ||
item->position > max_position_item->position) {
max_position_item = item;
}
}
if (max_position_item->next == nullptr) {
InsertAfter(to_insert, max_position_item);
} else {
InsertBeforeInstructions(to_insert, {max_position_item->next});
}
}
void Blacklist(const HloInstruction* inst) {
GetItem(inst)->blacklisted = true;
}
@ -276,6 +319,24 @@ class InstructionList {
item->position = before->position;
}
void InsertAfter(Item* item, Item* after) {
VLOG(3) << "InsertAfter: " << item->instruction->name() << " after "
<< after->instruction->name();
// Insert new item into linked list.
item->next = after->next;
item->prev = after;
after->next = item;
if (item->next != nullptr) {
item->next->prev = item;
}
// Assign the same position number to the newly added instruction as
// 'after'. This guarantees monotonicity of the position numbers, but not
// uniqueness.
item->position = after->position;
}
Item* first_;
// Item for each instruction.
@ -327,6 +388,7 @@ class MemoryUsageTracker {
MemoryUsageTracker(
const HloComputation* computation,
const HloRematerialization::ShapeSizeFunction& size_function,
const HloRematerialization::CompactShapeFunction& compact_shape_function,
const TuplePointsToAnalysis& points_to_analysis,
const InstructionList& instruction_list);
@ -338,6 +400,22 @@ class MemoryUsageTracker {
// EndInstruction memory for dead operand(s) is freed.
Status BeginInstruction(Item* item);
int64 RematerializationCost(const HloInstruction* instruction,
int64 memory_reduced, int64 memory_limit_bytes) {
// If none of the users of 'instruction' have been placed in the sequence
// (as tracked by memory_tracker), then rematerialization of 'instruction'
// is a zero-cost move of 'instruction' in the sequence.
if (!absl::c_any_of(
instruction->users(),
[this](const HloInstruction* inst) { return IsPlaced(inst); })) {
return 0;
}
CHECK_GT(memory_reduced, 0);
// Return the inverse of the benefit of rematerialization.
return memory_limit_bytes / memory_reduced;
}
// Finishes the placement of the current instruction. This frees any dead
// operands or dead result of the instruction. This must be called after
// each call to BeginInstruction.
@ -347,17 +425,28 @@ class MemoryUsageTracker {
// if the given instruction is rematerialized.
int64 MemoryReducedIfRematerialized(Item* item) const;
// Returns the number of bytes that the current memory usage will be reduced
// by if the given instruction is compressed.
int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;
// Returns the number of bytes that the current memory usage will be reduced
// by if the given sequence of instructions is rematerialized.
int64 MemoryReducedIfRematerialized(const absl::Span<Item*>& items) const;
Status AddCompressInstructions(Item* original_item, Item* compressed_item,
Item* uncompressed_item);
// Adjusts memory usage to account for the rematerialization of
// original_item for all remaining unplaced uses. The rematerialization
// is remat_item. This method should be called after the HLO graph has
// been transformed (rematerialization instruction created and connected to
// uses).
// been transformed (rematerialization instruction created and connected
// to uses).
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
std::pair<Item*, RematStrategy> PickRematerializationCandidate(
const InstructionList& instruction_list, int64 memory_limit_bytes,
absl::flat_hash_map<const HloInstruction*, bool>* remat_able);
// Returns whether the given instruction has been placed (BeginInstruction
// has been called with 'instruction' as the argument).
bool IsPlaced(const HloInstruction* instruction) const {
@ -390,6 +479,9 @@ class MemoryUsageTracker {
// The materialized size of the buffer in bytes.
const int64 size;
// Shape of the buffer.
Shape shape;
// Whether this buffer is live-out of the computation.
bool live_out;
@ -412,19 +504,21 @@ class MemoryUsageTracker {
}
};
// Get the compact shape of the given HLO instruction. An internal cache is
// used to avoid computing the shape multiple times.
StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);
// Creates a Buffer representing the given logical buffer. The buffer is added
// to buffers_ and a reference is returned.
Buffer& CreateBufferFromLogicalBuffer(
const LogicalBuffer* logical_buffer,
const TuplePointsToAnalysis& points_to_analysis,
const HloRematerialization::ShapeSizeFunction& size_function,
bool live_out) {
const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
bool has_indirect_uses = false;
ItemList users = GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &has_indirect_uses);
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
size_function(logical_buffer->shape()), std::move(users),
live_out, has_indirect_uses);
logical_buffer->shape(), std::move(users), live_out,
has_indirect_uses);
}
// Create a new buffer representing a rematerialization of given buffer for
@ -438,7 +532,7 @@ class MemoryUsageTracker {
for (Item* use : rematerialized_uses) {
CHECK(!use->placed) << use->instruction->name();
}
return NewBuffer(remat_item, original_buffer.size,
return NewBuffer(remat_item, original_buffer.shape,
std::move(rematerialized_uses), /*live_out=*/false,
/*has_indirect_uses=*/false);
}
@ -449,7 +543,8 @@ class MemoryUsageTracker {
// different computation.
int64 AllocatedSize(BufferId buffer_id) const {
const Buffer& buffer = buffers_.at(buffer_id);
HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
HloInstruction* inst = buffer.defining_instruction->instruction;
HloOpcode def_opcode = inst->opcode();
if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
return 0;
} else {
@ -482,12 +577,12 @@ class MemoryUsageTracker {
}
// Create a new buffer, add it to buffers_, and return a reference.
Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
bool live_out, bool has_indirect_uses) {
Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
ItemList&& users, bool live_out, bool has_indirect_uses) {
int buffer_id = buffers_.size();
buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
has_indirect_uses, users,
static_cast<int64>(users.size())});
buffers_.push_back(Buffer{
buffer_id, defining_instruction, size_function_(shape), shape, live_out,
has_indirect_uses, users, static_cast<int64>(users.size())});
return buffers_.back();
}
@ -498,6 +593,16 @@ class MemoryUsageTracker {
// (BeginInstruction/EndInstruction calls).
const InstructionList& instruction_list_;
// Size function returns the bytes of a given buffer.
const HloRematerialization::ShapeSizeFunction& size_function_;
// Converts a shape into compact form; returns the same shape if it is
// already considered compact.
const HloRematerialization::CompactShapeFunction& compact_shape_function_;
// A map that caches the known compact shape for each instruction.
absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;
// Memory usage at the currently placed instruction.
int64 memory_usage_ = 0;
@ -512,9 +617,13 @@ class MemoryUsageTracker {
MemoryUsageTracker::MemoryUsageTracker(
const HloComputation* computation,
const HloRematerialization::ShapeSizeFunction& size_function,
const HloRematerialization::CompactShapeFunction& compact_shape_function,
const TuplePointsToAnalysis& points_to_analysis,
const InstructionList& instruction_list)
: computation_(computation), instruction_list_(instruction_list) {
: computation_(computation),
instruction_list_(instruction_list),
size_function_(size_function),
compact_shape_function_(compact_shape_function) {
PointsToSet::BufferSet live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
@ -556,7 +665,7 @@ MemoryUsageTracker::MemoryUsageTracker(
}
} else {
buffer = &CreateBufferFromLogicalBuffer(
logical_buffer, points_to_analysis, size_function,
logical_buffer, points_to_analysis,
ContainsKey(live_out_set, logical_buffer));
item->buffers_defined.push_back(buffer->id);
for (Item* user : buffer->users) {
@ -566,6 +675,14 @@ MemoryUsageTracker::MemoryUsageTracker(
logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
}
// Trace the output of each instruction. This is so that we can properly
// track which outputs GTE instructions have.
for (const LogicalBuffer* logical_buffer :
points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
item->buffers_output.push_back(
logical_buffer_to_buffer_id[logical_buffer]);
}
}
XLA_VLOG_LINES(10, ToString());
DCHECK(Check());
@ -637,6 +754,29 @@ Status MemoryUsageTracker::EndInstruction() {
return Status::OK();
}
int64 MemoryUsageTracker::MemoryReducedIfCompressed(
Item* item, const Shape& compact_shape) const {
CHECK_NE(in_progress_item_, nullptr);
if (!item->placed || item == in_progress_item_) {
return 0;
}
int64 memory_reduced = 0;
// We only compress a single piece of output at a time.
CHECK_EQ(item->buffers_output.size(), 1);
BufferId buffer_id = item->buffers_output[0];
if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
const Buffer& buffer = buffers_.at(buffer_id);
memory_reduced += buffer.size;
int64 compact_shape_size = size_function_(compact_shape);
// Account for the buffer being compressed after the instruction.
memory_reduced -= compact_shape_size;
}
return memory_reduced;
}
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
CHECK_NE(in_progress_item_, nullptr);
if (!item->placed || item == in_progress_item_) {
@ -736,6 +876,56 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
return memory_reduced;
}
Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
Item* compressed_item,
Item* uncompressed_item) {
// Original buffer is now dead.
memory_usage_ -= size_function_(original_item->instruction->shape());
// Compressed buffer is now alive.
memory_usage_ += size_function_(compressed_item->instruction->shape());
ItemList placed_users;
ItemList unplaced_users;
CHECK_EQ(original_item->buffers_output.size(), 1);
BufferId original_buffer_id = original_item->buffers_output[0];
Buffer& original_buffer = buffers_.at(original_buffer_id);
for (Item* user : original_buffer.users) {
if (user->placed) {
CHECK(IsFinished(user)) << user->instruction->name();
placed_users.push_back(user);
} else {
unplaced_users.push_back(user);
}
}
original_buffer.users = std::move(placed_users);
original_buffer.unfinished_user_count = 0;
original_buffer.users.push_back(compressed_item);
Buffer& compressed_buffer =
NewBuffer(compressed_item, compressed_item->instruction->shape(),
{uncompressed_item}, /*live_out=*/false,
/*has_indirect_uses=*/false);
compressed_item->buffers_used = original_item->buffers_output;
compressed_item->buffers_output = {compressed_buffer.id};
compressed_item->buffers_defined.push_back(compressed_buffer.id);
Buffer& uncompressed_buffer =
NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
std::move(unplaced_users), /*live_out=*/false,
/*has_indirect_uses=*/false);
uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
uncompressed_item->buffers_output = {uncompressed_buffer.id};
uncompressed_item->buffers_defined = {uncompressed_buffer.id};
for (Item* user : uncompressed_buffer.users) {
BufferIdList& buffers_used = user->buffers_used;
std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
uncompressed_buffer.id);
}
return Status::OK();
}
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
Item* remat_item) {
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
@ -831,6 +1021,17 @@ string MemoryUsageTracker::ToString() const {
return output;
}
StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
auto it = compact_shape_.find(hlo);
if (it != compact_shape_.end()) {
return it->second;
}
const Shape& original_shape = hlo->shape();
TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
compact_shape_[hlo] = min_shape;
return min_shape;
}
bool MemoryUsageTracker::Check() const {
auto elements_are_unique = [](const BufferIdList& vec) {
return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
@ -917,12 +1118,15 @@ int64 RematerializationCost(const HloInstruction* instruction,
// candidate which reduces memory use at the program point of the current
// instruction as indicated by memory_tracker. nullptr is returned if no
// candidate can be found.
Item* PickRematerializationCandidate(
const MemoryUsageTracker& memory_tracker,
std::pair<Item*, RematStrategy>
MemoryUsageTracker::PickRematerializationCandidate(
const InstructionList& instruction_list, int64 memory_limit_bytes,
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
Item* best_item = nullptr;
int64 best_cost = 0;
RematStrategy best_strategy;
VLOG(5) << "Picking candidate";
// TODO(b/35244891): This is currently quadratic in the number of HLO
// instructions.
@ -947,44 +1151,215 @@ Item* PickRematerializationCandidate(
if (!CanBeRematerialized(candidate, remat_able)) {
VLOG(5) << "candidate " << candidate->name()
<< " not viable: is not rematerializable";
continue;
}
// If any of the candidate's control successors has been placed, we need to
// skip this candidate. Otherwise we would violate the control dependency.
bool control_successor_placed =
std::any_of(candidate->control_successors().begin(),
candidate->control_successors().end(),
[&memory_tracker](const HloInstruction* inst) {
return memory_tracker.IsPlaced(inst);
});
if (item->buffers_output.size() == 1) {
// Only consider compressing single output instruction.
const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
if (item->placed && item != in_progress_item_ &&
!output_buffer.live_out) {
const Shape& original_shape = item->instruction->shape();
if (original_shape.IsArray()) {
Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
const int64 memory_reduced =
MemoryReducedIfCompressed(item, compact_shape);
if (memory_reduced > 0) {
const int64 cost = memory_limit_bytes / memory_reduced;
if (best_item == nullptr || cost < best_cost) {
VLOG(3) << "candidate " << candidate->name() << "("
<< candidate->ToShortString() << ")"
<< " now best when compressed into "
<< compact_shape.ToString(true);
RematStrategy strategy;
strategy.kind = RematStrategy::kCompress;
best_strategy = strategy;
best_strategy.compact_shape = compact_shape;
best_item = item;
best_cost = cost;
}
}
}
}
}
// If any of the candidate's control successors has been placed, we need
// to skip this candidate. Otherwise we would violate the control dependency.
bool control_successor_placed = std::any_of(
candidate->control_successors().begin(),
candidate->control_successors().end(),
[this](const HloInstruction* inst) { return IsPlaced(inst); });
if (control_successor_placed) {
continue;
}
const int64 memory_reduced =
memory_tracker.MemoryReducedIfRematerialized(item);
const int64 memory_reduced = MemoryReducedIfRematerialized(item);
if (memory_reduced <= 0) {
VLOG(5) << "candidate " << candidate->name()
<< " memory reduced = " << memory_reduced << " <= 0";
continue;
}
if (memory_reduced > 0) {
const int cost =
RematerializationCost(candidate, memory_reduced, memory_limit_bytes);
const int cost = RematerializationCost(candidate, memory_tracker,
memory_reduced, memory_limit_bytes);
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
<< memory_reduced << ", cost per byte " << cost;
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
<< memory_reduced << ", cost per byte " << cost;
if (best_item == nullptr || cost < best_cost) {
VLOG(5) << "candidate " << candidate->name() << " now best";
best_item = item;
best_cost = cost;
if (best_item == nullptr || cost < best_cost) {
VLOG(5) << "candidate " << candidate->name() << " now best";
best_strategy.kind = RematStrategy::kRecompute;
best_item = item;
best_cost = cost;
}
}
}
return best_item;
return {best_item, best_strategy};
}
StatusOr<int64> RematerializeInstruction(
MemoryUsageTracker* memory_tracker, Item* best_item,
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
InstructionList* instruction_list) {
HloInstruction* best = best_item->instruction;
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
<< HumanReadableNumBytes(
memory_tracker->MemoryReducedIfRematerialized(best_item))
<< ")";
int64 net_instructions_added = 0;
HloComputation* computation = best->parent();
HloInstruction* remat =
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
// Add control dependencies to the new operation.
for (auto successor : best->control_successors()) {
TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
}
for (auto predecessor : best->control_predecessors()) {
TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
}
Item* remat_item = instruction_list->CreateItem(remat);
// Replace each remaining use of 'best' with the rematerialization.
std::vector<HloInstruction*> best_users_copy = best->users();
for (HloInstruction* user : best_users_copy) {
if (!memory_tracker->IsPlaced(user)) {
VLOG(2) << " Replacing use of " << best->name() << " in " << user->name()
<< " with " << remat->name();
TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
}
}
// Account for the rematerialization in the memory tracker.
TF_RETURN_IF_ERROR(
memory_tracker->AddRematerializedInstruction(best_item, remat_item));
// Insert rematerialized instruction right before the earliest unplaced
// use of the instruction *and* the earliest unplaced last use of any
// operands of remat. Unplaced uses of the remat's operands are included
// because we don't want to extend the live range of remat's operands as
// this could increase memory usage.
ItemList place_before;
for (auto user : remat->users()) {
place_before.push_back(instruction_list->GetItem(user));
}
for (auto* operand : remat->operands()) {
for (auto* operand_user : operand->users()) {
if (operand_user != remat) {
Item* operand_user_item = instruction_list->GetItem(operand_user);
if (!operand_user_item->placed) {
place_before.push_back(operand_user_item);
}
}
}
}
// Insert rematerialized instruction before any of its successors to
// preserve ordering regarding control dependency.
for (auto successor : remat->control_successors()) {
Item* successor_item = instruction_list->GetItem(successor);
// Assert to make sure we never remat an operation with control
// successor already placed.
CHECK(!successor_item->placed) << successor_item->instruction->name();
place_before.push_back(successor_item);
}
instruction_list->InsertBeforeInstructions(remat_item, place_before);
// If the rematerialized instruction is dead then rematerialization is
// essentially a move. Don't delete the instruction now because we don't
// want duplicate HloInstruction* values during the course of the
// transformation because we keep maps with HloInstruction* values as
// keys.
if (best->users().empty()) {
VLOG(2) << best->name() << " is now dead";
if (ContainsKey(*remat_move_instructions, best)) {
// Previously, 'best' was a rematerialization which killed the
// instruction it was a copying of. Now 'remat' is a rematerialization
// of 'best' and kills 'best'. Stop rematerializing this instruction
// to avoid an infinite loop.
instruction_list->Blacklist(remat);
}
remat_move_instructions->insert(remat);
} else {
net_instructions_added++;
}
return net_instructions_added;
}
StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
Item* best_item, const Shape& compact_shape,
InstructionList* instruction_list) {
HloInstruction* best = best_item->instruction;
VLOG(5) << "Transposing instruction " << best->name() << " (saving "
<< HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
best_item, compact_shape))
<< ") to" << compact_shape.ToString(true);
HloComputation* computation = best->parent();
HloInstruction* compressed = computation->AddInstruction(
HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));
HloInstruction* uncompressed = computation->AddInstruction(
HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));
Item* compressed_item = instruction_list->CreateItem(compressed);
compressed_item->placed = true;
Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
// Replace each remaining use of 'best' with the uncompressed.
std::vector<HloInstruction*> best_users_copy = best->users();
for (HloInstruction* user : best_users_copy) {
if (!memory_tracker->IsPlaced(user)) {
VLOG(5) << " Replacing use of " << best->name() << " in " << user->name()
<< " with " << uncompressed->name();
TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
}
}
// Account for the rematerialization in the memory tracker.
TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
best_item, compressed_item, uncompressed_item));
// Insert the uncompressed instruction right before the earliest unplaced
// use of the original instruction.
ItemList place_before;
for (auto user : uncompressed->users()) {
place_before.push_back(instruction_list->GetItem(user));
}
instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
instruction_list->InsertAfterInstructions(compressed_item, {best_item});
return 2;
}
} // namespace
@ -993,7 +1368,8 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
const HloComputation* computation,
const HloInstructionSequence& order) const {
InstructionList instruction_list(order);
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
MemoryUsageTracker tracker(computation, size_function_,
compact_shape_function_, *points_to_analysis_,
instruction_list);
int64 peak_memory = tracker.memory_usage();
for (auto* item = instruction_list.first(); item != nullptr;
@ -1037,6 +1413,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
InstructionList instruction_list(schedule->sequence(computation));
MemoryUsageTracker memory_tracker(computation, size_function_,
compact_shape_function_,
*points_to_analysis_, instruction_list);
bool changed = false;
@ -1086,8 +1463,11 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
callee_usage)
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
Item* best_item = PickRematerializationCandidate(
memory_tracker, instruction_list, memory_limit_bytes, &remat_able);
Item* best_item;
RematStrategy best_strategy;
std::tie(best_item, best_strategy) =
memory_tracker.PickRematerializationCandidate(
instruction_list, memory_limit_bytes, &remat_able);
if (best_item == nullptr) {
VLOG(3) << "Unable to find rematerialization candidate at program "
@ -1106,81 +1486,19 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
changed = true;
remat_count++;
HloInstruction* remat =
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
// Add control dependencies to the new operation.
for (auto successor : best->control_successors()) {
TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
}
for (auto predecessor : best->control_predecessors()) {
TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
}
Item* remat_item = instruction_list.CreateItem(remat);
// Replace each remaining use of 'best' with the rematerialization.
std::vector<HloInstruction*> best_users_copy = best->users();
for (HloInstruction* user : best_users_copy) {
if (!memory_tracker.IsPlaced(user)) {
VLOG(2) << " Replacing use of " << best->name() << " in "
<< user->name() << " with " << remat->name();
TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
}
}
// Account for the rematerialization in the memory tracker.
TF_RETURN_IF_ERROR(
memory_tracker.AddRematerializedInstruction(best_item, remat_item));
// Insert rematerialized instruction right before the earliest unplaced
// use of the instruction *and* the earliest unplaced last use of any
// operands of remat. Unplaced uses of the remat's operands are included
// because we don't want to extend the live range of remat's operands as
// this could increase memory usage.
ItemList place_before;
for (auto user : remat->users()) {
place_before.push_back(instruction_list.GetItem(user));
}
for (auto* operand : remat->operands()) {
for (auto* operand_user : operand->users()) {
if (operand_user != remat) {
Item* operand_user_item = instruction_list.GetItem(operand_user);
if (!operand_user_item->placed) {
place_before.push_back(operand_user_item);
}
}
}
}
// Insert rematerialized instruction before any of its successors to
// preserve ordering regarding control dependency.
for (auto successor : remat->control_successors()) {
Item* successor_item = instruction_list.GetItem(successor);
// Assert to make sure we never remat an operation with control
// successor already placed.
CHECK(!successor_item->placed) << successor_item->instruction->name();
place_before.push_back(successor_item);
}
instruction_list.InsertBeforeInstructions(remat_item, place_before);
// If the rematerialized instruction is dead then rematerialization is
// essentially a move. Don't delete the instruction now because we don't
// want duplicate HloInstruction* values during the course of the
// transformation because we keep maps with HloInstruction* values as
// keys.
if (best->users().empty()) {
VLOG(2) << best->name() << " is now dead";
if (ContainsKey(remat_move_instructions, best)) {
// Previously, 'best' was a rematerialization which killed the
// instruction it was a copying of. Now 'remat' is a rematerialization
// of 'best' and kills 'best'. Stop rematerializing this instruction
// to avoid an infinite loop.
instruction_list.Blacklist(remat);
}
remat_move_instructions.insert(remat);
int64 added_instruction = 0;
if (best_strategy.kind == RematStrategy::kCompress) {
TF_ASSIGN_OR_RETURN(added_instruction,
CompressInstruction(&memory_tracker, best_item,
best_strategy.compact_shape,
&instruction_list));
} else {
net_instructions_added++;
TF_ASSIGN_OR_RETURN(added_instruction,
RematerializeInstruction(&memory_tracker, best_item,
&remat_move_instructions,
&instruction_list));
}
net_instructions_added += added_instruction;
VLOG(1) << "memory_usage after rematerialization = "
<< HumanReadableNumBytes(memory_tracker.memory_usage());
@ -1357,7 +1675,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
sizes_->after_bytes = current_peak_memory;
}
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
if (current_peak_memory > memory_limit_bytes_) {
LOG(WARNING) << absl::StrFormat(


@ -24,6 +24,8 @@
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
@ -38,6 +40,8 @@ class HloRematerialization : public HloModulePass {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;
// Helper struct that communicates the before / after sizes for the
// rematerialization process.
struct RematerializationSizes {
@ -45,6 +49,8 @@ class HloRematerialization : public HloModulePass {
int64 after_bytes;
};
static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
// Constructor parameters:
//
// size_function: Function which returns the size in bytes of the top-level
@ -57,12 +63,20 @@ class HloRematerialization : public HloModulePass {
// sizes: Pointer to data structure which records the peak memory usage of
// the HLO module before/after rematerialization. Value are set during
// Run(). Can be nullptr.
HloRematerialization(const ShapeSizeFunction& size_function,
int64 memory_limit_bytes, RematerializationSizes* sizes)
//
// compact_shape_function: Function which returns the compact form of a
// shape. If nullptr is provided, a default identity function is used.
explicit HloRematerialization(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
RematerializationSizes* sizes,
CompactShapeFunction compact_shape_function = nullptr)
: size_function_(size_function),
memory_limit_bytes_(memory_limit_bytes),
sizes_(sizes) {}
~HloRematerialization() {}
sizes_(sizes),
compact_shape_function_(compact_shape_function == nullptr
? DefaultCompactShapeFunction
: std::move(compact_shape_function)) {}
~HloRematerialization() override = default;
absl::string_view name() const override { return "rematerialization"; }
@ -109,6 +123,10 @@ class HloRematerialization : public HloModulePass {
// module before/after rematerialization
RematerializationSizes* sizes_;
// Converts a shape into compact form; returns the same shape if it is
// already considered compact.
const CompactShapeFunction compact_shape_function_;
// Call graph of the hlo_module.
std::unique_ptr<CallGraph> call_graph_;


@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@ -534,6 +533,142 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest,
::testing::Values(true, false));
class CompressingRematerializationTest : public RematerializationTestBase {
protected:
// A special shape size function, which pads the most-minor dimension up to
// a multiple of 64.
static int64 ShapeSizePadMinorTo64(const Shape& shape) {
if (shape.IsTuple()) {
// Size of a tuple is 4 bytes.
return 4;
}
Shape descending_shape =
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape);
int64 size =
ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type());
for (int64 i = 0; i < descending_shape.rank(); ++i) {
int64 dim = shape.dimensions(i);
if (i == descending_shape.rank() - 1) {
dim = RoundUpToNearest<int64>(dim, 64);
}
size *= dim;
}
return size;
}
// Swap the two most-minor dimensions if the second-minor dimension is bigger
// than the most-minor dimension.
static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
Shape result = shape;
Layout layout = result.layout();
int64 most_minor_index = layout.minor_to_major()[0];
int64 second_minor_index = layout.minor_to_major()[1];
int64 most_minor = result.dimensions(most_minor_index);
int64 second_minor = result.dimensions(second_minor_index);
if (most_minor < second_minor) {
result.set_dimensions(most_minor_index, second_minor);
result.set_dimensions(second_minor_index, most_minor);
}
return result;
}
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
HloModule* module) {
TF_EXPECT_OK(verifier().Run(module).status());
HloRematerialization remat(ShapeSizePadMinorTo64, memory_limit_bytes,
/*sizes=*/nullptr, ChooseCompactLayoutForShape);
return remat.Run(module);
}
};
// Test rematerialization of a single instruction.
TEST_F(CompressingRematerializationTest, SingleRemat) {
const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_float {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%constant = f32[] constant(0)
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/30 * 1024, module.get()));
EXPECT_TRUE(changed);
HloInstruction* broadcast =
module->entry_computation()->GetInstructionWithName("broadcast.0");
HloInstruction* reduce =
module->entry_computation()->GetInstructionWithName("reduce.1");
EXPECT_THAT(reduce,
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
}
TEST_F(CompressingRematerializationTest, AllUsersUseSameCopy) {
const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_float {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%constant = f32[] constant(0)
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.2 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
%reduce.3 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add.2 = f32[] add(f32[] %reduce.2, f32[] %reduce.3)
ROOT %tuple = (f32[], f32[]) tuple (f32[] add, f32[] add.2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/30 * 1024, module.get()));
EXPECT_TRUE(changed);
HloInstruction* broadcast =
module->entry_computation()->GetInstructionWithName("broadcast.0");
// Both reduces reuse the same copy instruction.
HloInstruction* reduce_2 =
module->entry_computation()->GetInstructionWithName("reduce.2");
HloInstruction* reduce_3 =
module->entry_computation()->GetInstructionWithName("reduce.3");
EXPECT_THAT(reduce_2,
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
EXPECT_THAT(reduce_3,
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
}
} // namespace
} // namespace xla