Implements an incremental greedy algorithm for rematerialization. Whenever memory use at any point in the execution exceeds the memory limit, we attempt to rematerialize one instruction at a time to reduce memory usage. If no such candidate is found, we expand the search space to rematerialize 2 consecutive instructions at a time, and so on, doubling the size of the search space each time. Whenever we find a block of instructions whose rematerialization reduces memory usage, we reset the block size to 1 and resume the algorithm.
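In sketch form, the outer loop looks roughly like this (MemoryUsage() and TryRematerializeBestBlock() are hypothetical stand-ins; the actual loop is in HloRematerialization::RematerializeComputation in the diff below):

// Sketch only: illustrates the doubling search described above, not the exact code.
int min_block_size = 1;  // smallest block length to consider
int max_block_size = 1;  // largest block length to consider
while (MemoryUsage() > memory_limit_bytes) {
  // Rematerialize the cheapest block whose length lies in
  // [min_block_size, max_block_size]; returns 0 if no block reduces memory.
  int remat_count = TryRematerializeBestBlock(min_block_size, max_block_size);
  if (remat_count == 0) {
    // No candidate found: widen the search to longer blocks.
    min_block_size = max_block_size + 1;
    max_block_size = 2 * max_block_size;
  } else {
    // Found a block that reduced memory: go back to single instructions.
    min_block_size = 1;
    max_block_size = 1;
  }
  if (max_block_size > block_size_limit) {
    break;  // Give up once the block size exceeds the configured limit.
  }
}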

PiperOrigin-RevId: 293246777
Change-Id: I7488276578e56d1b35b0e963d76b24adefacff8d
A. Unique TensorFlower 2020-02-04 15:47:23 -08:00 committed by TensorFlower Gardener
parent e7dca8b51c
commit 63224a7501
3 changed files with 265 additions and 152 deletions


@@ -86,13 +86,13 @@ bool IsRematerializable(const HloInstruction* instruction) {
// cache before, and eventually calling the IsRematerializable() API.
bool CanBeRematerialized(
const HloInstruction* instruction,
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
auto it = remat_able->find(instruction);
if (it != remat_able->end()) {
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
auto it = rematerializable_map->find(instruction);
if (it != rematerializable_map->end()) {
return it->second;
}
bool rematerializable = IsRematerializable(instruction);
(*remat_able)[instruction] = rematerializable;
(*rematerializable_map)[instruction] = rematerializable;
return rematerializable;
}
@@ -381,14 +381,22 @@ class MemoryUsageTracker {
// EndInstruction memory for dead operand(s) is freed.
Status BeginInstruction(Item* item);
int64 RematerializationCost(const HloInstruction* instruction,
int64 RematerializationCost(const std::vector<Item*>& items,
int64 memory_reduced, int64 memory_limit_bytes) {
// If none of the users of 'instruction' have been placed in the sequence
// (as tracked by memory_tracker), then rematerialization of 'instruction'
// is a zero-cost move of 'instruction' in the sequence.
if (!absl::c_any_of(
instruction->users(),
[this](const HloInstruction* inst) { return IsPlaced(inst); })) {
// If none of the users of any 'item' have been placed in the
// sequence (as tracked by memory_tracker), then rematerialization of
// 'item' is a zero-cost move of 'item->instruction' in the sequence.
bool zero_cost_move = true;
for (auto* item : items) {
auto* instruction = item->instruction;
if (absl::c_any_of(
instruction->users(),
[this](const HloInstruction* inst) { return IsPlaced(inst); })) {
zero_cost_move = false;
break;
}
}
if (zero_cost_move) {
return 0;
}
@@ -425,9 +433,16 @@ class MemoryUsageTracker {
// to uses).
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
std::pair<Item*, RematStrategy> PickRematerializationCandidate(
// Selects and returns the best candidate instructions for rematerialization.
// A sequence of candidate instructions of length between min_block_size and
// max_block_size (both inclusive) with the lowest rematerialization cost is
// selected among those candidates which reduce memory use at the program
// point of the current instruction as indicated by memory_tracker. Returns an
// empty vector if no candidates are found.
std::pair<std::vector<Item*>, RematStrategy> PickRematerializationCandidates(
const InstructionList& instruction_list, int64 memory_limit_bytes,
absl::flat_hash_map<const HloInstruction*, bool>* remat_able);
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
int min_block_size, int max_block_size);
// Returns whether the given instruction has been placed (BeginInstruction
// has been called with 'instruction' as the argument).
@@ -438,6 +453,9 @@ class MemoryUsageTracker {
// Returns whether 'item' has any unplaced users.
bool HasUnplacedUsers(Item* item) const;
// Returns whether 'item' is currently in progress.
bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }
// Returns the current memory usage. This is the sum of sizes of all live
// values.
int64 memory_usage() const { return memory_usage_; }
@@ -1121,115 +1139,166 @@ int64 RematerializationCost(const HloInstruction* instruction,
return memory_limit_bytes / memory_reduced;
}
// Selects and returns the best candidate instruction for rematerialization.
// The instruction with lowest rematerialization cost is selected among those
// candidate which reduce memory use at the program point of the current
// instruction as indicated by memory_tracker. nullptr is returned if no
// candidate can be found.
std::pair<Item*, RematStrategy>
MemoryUsageTracker::PickRematerializationCandidate(
// Returns a block of up to min_block_size consecutive candidate instructions
// from instruction_list starting from start_item. Returns fewer than
// min_block_size instructions if the block of placed instructions starting
// from start_item is smaller than min_block_size.
std::vector<Item*> GetInitialBlock(const InstructionList& instruction_list,
const MemoryUsageTracker& tracker,
Item* start_item, int min_block_size) {
std::vector<Item*> item_block;
Item* curr_item = start_item;
for (int i = 0; i < min_block_size; ++i) {
if (curr_item == nullptr || !curr_item->placed ||
tracker.IsInProgressItem(curr_item)) {
break;
}
item_block.push_back(curr_item);
curr_item = instruction_list.next(curr_item);
}
return item_block;
}
// Returns whether any instruction in 'block' is blacklisted or
// non-rematerializable.
bool AnyBlacklistedOrNonRematerializable(
const std::vector<Item*>& block,
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
for (auto* item : block) {
if (item->blacklisted) {
return true;
}
if (!CanBeRematerialized(item->instruction, rematerializable_map)) {
return true;
}
}
return false;
}
std::pair<std::vector<Item*>, RematStrategy>
MemoryUsageTracker::PickRematerializationCandidates(
const InstructionList& instruction_list, int64 memory_limit_bytes,
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
Item* best_item = nullptr;
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
int min_block_size, int max_block_size) {
std::vector<Item*> best_items;
int64 best_cost = 0;
RematStrategy best_strategy;
VLOG(5) << "Picking candidate";
VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", "
<< max_block_size << "]";
// TODO(b/35244891): This is currently quadratic in the number of HLO
// instructions.
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
if (!item->placed) {
// Only iterate up to the currently placed instruction.
// We are trying to reduce memory usage at the placed
// instruction so rematerializing later values is of no benefit.
for (auto* start_item = instruction_list.first(); start_item != nullptr;
start_item = instruction_list.next(start_item)) {
std::vector<Item*> block =
GetInitialBlock(instruction_list, *this, start_item, min_block_size);
if (block.size() < min_block_size) {
// There are no more blocks of size at least min_block_size consisting of
// placed instructions.
break;
}
HloInstruction* candidate = item->instruction;
VLOG(5) << "considering rematerialization candidate " << candidate->name();
if (item->blacklisted) {
// Skip instructions on the blacklist to avoid infinite loops of
// rematerializing the same instruction(s) repeatedly.
VLOG(5) << "candidate " << candidate->name()
<< " is excluded from rematerialization";
// If any item in the starting block is blacklisted or non-rematerializable,
// skip this block and move on to the next start_item (we could actually jump
// past the last invalid item in this block, but we ignore that optimization
// for now).
if (AnyBlacklistedOrNonRematerializable(block, rematerializable_map)) {
continue;
}
if (!CanBeRematerialized(candidate, remat_able)) {
VLOG(5) << "candidate " << candidate->name()
<< " not viable: is not rematerializable";
while (block.size() <= max_block_size) {
// block size = 1 is treated separately since we consider compression in
// this case only.
if (block.size() == 1) {
auto* item = block[0];
auto* candidate = item->instruction;
if (item->buffers_output.size() == 1 &&
(mode_ ==
HloRematerialization::RematerializationMode::kCompressOnly ||
mode_ == HloRematerialization::RematerializationMode::
kRecomputeAndCompress)) {
// Only consider compressing single output instruction.
const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
continue;
}
if (item->buffers_output.size() == 1 &&
(mode_ == HloRematerialization::RematerializationMode::kCompressOnly ||
mode_ == HloRematerialization::RematerializationMode::
kRecomputeAndCompress)) {
// Only consider compressing single output instruction.
const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
if (item->placed && item != in_progress_item_ &&
!output_buffer.live_out) {
const Shape& original_shape = item->instruction->shape();
if (original_shape.IsArray()) {
Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
const int64 memory_reduced =
MemoryReducedIfCompressed(item, compact_shape);
if (memory_reduced > 0) {
const int64 cost = memory_limit_bytes / memory_reduced;
if (best_item == nullptr || cost < best_cost) {
VLOG(3) << "candidate " << candidate->name() << "("
<< candidate->ToShortString() << ")"
<< " now best when compressed into "
<< compact_shape.ToString(true);
RematStrategy strategy;
strategy.kind = RematStrategy::kCompress;
best_strategy = strategy;
best_strategy.compact_shape = compact_shape;
best_item = item;
best_cost = cost;
if (item->placed && item != in_progress_item_ &&
!output_buffer.live_out) {
const Shape& original_shape = item->instruction->shape();
if (original_shape.IsArray()) {
Shape compact_shape =
GetCompactShape(item->instruction).ValueOrDie();
const int64 memory_reduced =
MemoryReducedIfCompressed(item, compact_shape);
if (memory_reduced > 0) {
const int64 cost = memory_limit_bytes / memory_reduced;
if (best_items.empty() || cost < best_cost) {
VLOG(3) << "candidate " << candidate->name() << "("
<< candidate->ToShortString() << ")"
<< " now best when compressed into "
<< compact_shape.ToString(true);
RematStrategy strategy;
strategy.kind = RematStrategy::kCompress;
best_strategy = strategy;
best_strategy.compact_shape = compact_shape;
best_items = block;
best_cost = cost;
}
}
}
}
}
}
}
// If any of the candidate's control successor has been placed, we need
// to skip this candidate. Otherwise we will violate control dependency.
bool control_successor_placed = std::any_of(
candidate->control_successors().begin(),
candidate->control_successors().end(),
[this](const HloInstruction* inst) { return IsPlaced(inst); });
if (control_successor_placed) {
continue;
}
// Do not consider recomputation in compress-only mode.
if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
continue;
}
const int64 memory_reduced = MemoryReducedIfRematerialized(item);
if (memory_reduced > 0) {
const int cost =
RematerializationCost(candidate, memory_reduced, memory_limit_bytes);
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
<< memory_reduced << ", cost per byte " << cost;
if (best_item == nullptr || cost < best_cost) {
VLOG(5) << "candidate " << candidate->name() << " now best";
best_strategy.kind = RematStrategy::kRecompute;
best_item = item;
best_cost = cost;
// Do not consider recomputation in compress-only mode.
if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
// break out of this loop. Move on to the next start_item.
break;
}
// If any control successor of any candidate in the block has already been
// placed, we must skip this block; otherwise we would violate a control
// dependency.
bool control_successor_placed = false;
for (auto* item : block) {
HloInstruction* candidate = item->instruction;
if (std::any_of(candidate->control_successors().begin(),
candidate->control_successors().end(),
[this](const HloInstruction* inst) {
return IsPlaced(inst);
})) {
control_successor_placed = true;
break;
}
}
if (control_successor_placed) {
// break out of this loop. Move on to the next start_item.
break;
}
const int64 memory_reduced = MemoryReducedIfRematerialized(block);
if (memory_reduced > 0) {
const int cost =
RematerializationCost(block, memory_reduced, memory_limit_bytes);
VLOG(5) << "Candidate block of size " << block.size()
<< " starting from " << block[0]->instruction->name()
<< ", memory reduced " << memory_reduced << ", cost per byte "
<< cost;
if (best_items.empty() || cost < best_cost) {
VLOG(5) << "Candidate block of size " << block.size()
<< " starting from " << block[0]->instruction->name()
<< " now best";
best_strategy.kind = RematStrategy::kRecompute;
best_items = block;
best_cost = cost;
}
}
// Time to update the block to include the next instruction.
auto* last_item = block[block.size() - 1];
auto* next_item = instruction_list.next(last_item);
if (next_item == nullptr || next_item->blacklisted ||
!next_item->placed || next_item == in_progress_item_ ||
!CanBeRematerialized(next_item->instruction, rematerializable_map)) {
break;
}
block.push_back(next_item);
}
}
return {best_item, best_strategy};
return {best_items, best_strategy};
}
bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
@@ -1402,6 +1471,60 @@ StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
return 2;
}
// A simple struct to encapsulate the number of instructions added during
// rematerialization.
struct InstructionsAdded {
// Total count of instructions rematerialized.
int remat_count;
// Total count of instructions rematerialized minus number of original
// instructions that are now dead.
int net_instructions_added;
};
// Rematerializes the best block of instructions of size between min_block_size
// and max_block_size (both inclusive) if at least one candidate block of
// instructions can be found. Returns number of instructions rematerialized.
StatusOr<InstructionsAdded> RematerializeBestBlock(
int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker,
InstructionList* instruction_list, int64 memory_limit_bytes,
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions) {
CHECK(min_block_size > 0) << "Non-positive block size.";
std::vector<Item*> best_items;
RematStrategy best_strategy;
std::tie(best_items, best_strategy) =
memory_tracker->PickRematerializationCandidates(
*instruction_list, memory_limit_bytes, rematerializable_map,
min_block_size, max_block_size);
InstructionsAdded num_instructions_added;
num_instructions_added.remat_count = best_items.size();
if (best_items.empty()) {
num_instructions_added.net_instructions_added = 0;
return num_instructions_added;
}
if (best_strategy.kind == RematStrategy::kCompress) {
CHECK(best_items.size() == 1)
<< "More than one instruction compressed simultaneously.";
HloInstruction* best = best_items[0]->instruction;
VLOG(1) << "Compressing instruction " << best->name() << " (saving "
<< HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
best_items[0], best_strategy.compact_shape))
<< ")";
TF_ASSIGN_OR_RETURN(
num_instructions_added.net_instructions_added,
CompressInstruction(memory_tracker, best_items[0],
best_strategy.compact_shape, instruction_list));
} else {
TF_ASSIGN_OR_RETURN(
num_instructions_added.net_instructions_added,
RematerializeInstructions(memory_tracker, &best_items,
remat_move_instructions, instruction_list));
}
return num_instructions_added;
}
} // namespace
StatusOr<int64> HloRematerialization::ComputePeakMemory(
@@ -1465,7 +1588,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
// The map from instructions to their rematerializable status.
absl::flat_hash_map<const HloInstruction*, bool> remat_able;
absl::flat_hash_map<const HloInstruction*, bool> rematerializable_map;
// The peak memory of the computation at any point in the instruction
// sequence.
@@ -1496,6 +1619,11 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< "/" << instruction_list.size() << "]";
instruction_index++;
// Initialize both min_block_size and max_block_size to 1 so that only
// single instruction rematerialization is considered first.
int min_block_size = 1;
int max_block_size = 1;
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
VLOG(2) << "Over memory limit at instruction " << instruction->name()
<< ", using "
@@ -1503,53 +1631,31 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
callee_usage)
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
Item* best_item;
RematStrategy best_strategy;
std::tie(best_item, best_strategy) =
memory_tracker.PickRematerializationCandidate(
instruction_list, memory_limit_bytes, &remat_able);
if (best_item == nullptr) {
VLOG(3) << "Unable to find rematerialization candidate at program "
"point "
<< instruction->name() << ". Memory usage = "
<< HumanReadableNumBytes(memory_tracker.memory_usage() +
callee_usage);
break;
}
HloInstruction* best = best_item->instruction;
changed = true;
remat_count++;
int64 num_instructions_added = 0;
if (best_strategy.kind == RematStrategy::kCompress) {
VLOG(1) << "Compressing instruction " << best->name() << " (saving "
<< HumanReadableNumBytes(
memory_tracker.MemoryReducedIfCompressed(
best_item, best_strategy.compact_shape))
<< ")";
TF_ASSIGN_OR_RETURN(num_instructions_added,
CompressInstruction(&memory_tracker, best_item,
best_strategy.compact_shape,
&instruction_list));
} else {
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
<< HumanReadableNumBytes(
memory_tracker.MemoryReducedIfRematerialized(best_item))
<< ")";
std::vector<Item*> best_items{best_item};
TF_ASSIGN_OR_RETURN(num_instructions_added,
RematerializeInstructions(
&memory_tracker, &best_items,
&remat_move_instructions, &instruction_list));
}
net_instructions_added += num_instructions_added;
TF_ASSIGN_OR_RETURN(InstructionsAdded instructions_added,
RematerializeBestBlock(
min_block_size, max_block_size, &memory_tracker,
&instruction_list, memory_limit_bytes,
&rematerializable_map, &remat_move_instructions));
net_instructions_added += instructions_added.net_instructions_added;
remat_count += instructions_added.remat_count;
VLOG(1) << "memory_usage after rematerialization = "
<< HumanReadableNumBytes(memory_tracker.memory_usage());
if (instructions_added.remat_count == 0) {
// Unable to find a block to rematerialize.
// Consider doubling the block size.
min_block_size = max_block_size + 1;
max_block_size = 2 * max_block_size;
} else {
// Found a valid block. Reset to start looking for single instructions
// again.
changed = true;
min_block_size = 1;
max_block_size = 1;
}
if (max_block_size > block_size_limit_) {
break;
}
}
const CallSite* callsite = call_graph_node.GetCallSite(instruction);


@@ -83,12 +83,14 @@ class HloRematerialization : public HloModulePass {
explicit HloRematerialization(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
RematerializationSizes* sizes, RematerializationPass pass_location,
int block_size_limit,
CompactShapeFunction compact_shape_function = nullptr,
RematerializationMode mode = RematerializationMode::kRecomputeAndCompress)
: size_function_(size_function),
memory_limit_bytes_(memory_limit_bytes),
sizes_(sizes),
pass_location_(pass_location),
block_size_limit_(block_size_limit),
compact_shape_function_(compact_shape_function == nullptr
? DefaultCompactShapeFunction
: std::move(compact_shape_function)),
@@ -144,6 +146,10 @@ class HloRematerialization : public HloModulePass {
// multi-output fusion.
RematerializationPass pass_location_;
// Maximum number of consecutive instructions to consider for
// rematerialization.
int block_size_limit_;
// Converts a shape into compact form, returns the same shape if a shape is
// already considered compact.
const CompactShapeFunction compact_shape_function_;


@@ -50,7 +50,8 @@ class HloRematerializationTest : public RematerializationTestBase {
HloRematerialization remat(
ByteSizeOf, memory_limit_bytes,
/*sizes=*/nullptr,
HloRematerialization::RematerializationPass::kPreFusion);
HloRematerialization::RematerializationPass::kPreFusion,
/*block_size_limit=*/1);
return remat.Run(module);
}
};
@@ -582,7 +583,7 @@ class CompressingRematerializationTest : public RematerializationTestBase {
ShapeSizePadMinorTo64, memory_limit_bytes,
/*sizes=*/nullptr,
HloRematerialization::RematerializationPass::kPreFusion,
ChooseCompactLayoutForShape);
/*block_size_limit=*/1, ChooseCompactLayoutForShape);
return remat.Run(module);
}
};