Implements an incremental greedy algorithm for rematerialization. Whenever the memory use at any point in the execution goes over the memory limit, we attempt to rematerialize one instruction at a time to reduce memory usage. If no such candidate is found, we expand the search space to instead rematerialize 2 consecutive instructions at a time, and so on, doubling the maximum block size each time until it exceeds the configured block size limit. Whenever we find a block of instructions whose rematerialization reduces the memory usage, we reset the block size to 1 and resume the algorithm.
PiperOrigin-RevId: 293246777 Change-Id: I7488276578e56d1b35b0e963d76b24adefacff8d
This commit is contained in:
parent
e7dca8b51c
commit
63224a7501
@ -86,13 +86,13 @@ bool IsRematerializable(const HloInstruction* instruction) {
|
||||
// cache before, and eventually calling the IsRematerializable() API.
|
||||
bool CanBeRematerialized(
|
||||
const HloInstruction* instruction,
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
|
||||
auto it = remat_able->find(instruction);
|
||||
if (it != remat_able->end()) {
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
|
||||
auto it = rematerializable_map->find(instruction);
|
||||
if (it != rematerializable_map->end()) {
|
||||
return it->second;
|
||||
}
|
||||
bool rematerializable = IsRematerializable(instruction);
|
||||
(*remat_able)[instruction] = rematerializable;
|
||||
(*rematerializable_map)[instruction] = rematerializable;
|
||||
return rematerializable;
|
||||
}
|
||||
|
||||
@ -381,14 +381,22 @@ class MemoryUsageTracker {
|
||||
// EndInstruction memory for dead operand(s) is freed.
|
||||
Status BeginInstruction(Item* item);
|
||||
|
||||
int64 RematerializationCost(const HloInstruction* instruction,
|
||||
int64 RematerializationCost(const std::vector<Item*>& items,
|
||||
int64 memory_reduced, int64 memory_limit_bytes) {
|
||||
// If none of the users of 'instruction' have been placed in the sequence
|
||||
// (as tracked by memory_tracker), then rematerialization of 'instruction'
|
||||
// is a zero-cost move of 'instruction' in the sequence.
|
||||
if (!absl::c_any_of(
|
||||
instruction->users(),
|
||||
[this](const HloInstruction* inst) { return IsPlaced(inst); })) {
|
||||
// If none of the users of any 'item' have been placed in the
|
||||
// sequence (as tracked by memory_tracker), then rematerialization of
|
||||
// 'item' is a zero-cost move of 'item->instruction' in the sequence.
|
||||
bool zero_cost_move = true;
|
||||
for (auto* item : items) {
|
||||
auto* instruction = item->instruction;
|
||||
if (absl::c_any_of(
|
||||
instruction->users(),
|
||||
[this](const HloInstruction* inst) { return IsPlaced(inst); })) {
|
||||
zero_cost_move = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (zero_cost_move) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -425,9 +433,16 @@ class MemoryUsageTracker {
|
||||
// to uses).
|
||||
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
|
||||
|
||||
std::pair<Item*, RematStrategy> PickRematerializationCandidate(
|
||||
// Selects and returns the best candidate instructions for rematerialization.
|
||||
// A sequence of candidate instructions of length between min_block_size and
|
||||
// max_block_size (both inclusive) with the lowest rematerialization cost is
|
||||
// selected among those candidates which reduce memory use at the program
|
||||
// point of the current instruction as indicated by memory_tracker. Returns an
|
||||
// empty vector if no candidates are found.
|
||||
std::pair<std::vector<Item*>, RematStrategy> PickRematerializationCandidates(
|
||||
const InstructionList& instruction_list, int64 memory_limit_bytes,
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* remat_able);
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
|
||||
int min_block_size, int max_block_size);
|
||||
|
||||
// Returns whether the given instruction has been placed (BeginInstruction
|
||||
// has been called with 'instruction' as the argument).
|
||||
@ -438,6 +453,9 @@ class MemoryUsageTracker {
|
||||
// Returns whether 'item' has any unplaced users.
|
||||
bool HasUnplacedUsers(Item* item) const;
|
||||
|
||||
// Returns whether 'item' is currently in progress.
|
||||
bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }
|
||||
|
||||
// Returns the current memory usage. This is the sum of sizes of all live
|
||||
// values.
|
||||
int64 memory_usage() const { return memory_usage_; }
|
||||
@ -1121,115 +1139,166 @@ int64 RematerializationCost(const HloInstruction* instruction,
|
||||
return memory_limit_bytes / memory_reduced;
|
||||
}
|
||||
|
||||
// Selects and returns the best candidate instruction for rematerialization.
|
||||
// The instruction with lowest rematerialization cost is selected among those
|
||||
// candidate which reduce memory use at the program point of the current
|
||||
// instruction as indicated by memory_tracker. nullptr is returned if no
|
||||
// candidate can be found.
|
||||
std::pair<Item*, RematStrategy>
|
||||
MemoryUsageTracker::PickRematerializationCandidate(
|
||||
// Returns a block of up to min_block_size consecutive candidate instructions
|
||||
// from instruction_list starting from start_item. Returns fewer than
|
||||
// min_block_size instructions if the block of unplaced instructions starting
|
||||
// from start_item is smaller than min_block_size.
|
||||
std::vector<Item*> GetInitialBlock(const InstructionList& instruction_list,
|
||||
const MemoryUsageTracker& tracker,
|
||||
Item* start_item, int min_block_size) {
|
||||
std::vector<Item*> item_block;
|
||||
Item* curr_item = start_item;
|
||||
for (int i = 0; i < min_block_size; ++i) {
|
||||
if (curr_item == nullptr || !curr_item->placed ||
|
||||
tracker.IsInProgressItem(curr_item)) {
|
||||
break;
|
||||
}
|
||||
item_block.push_back(curr_item);
|
||||
curr_item = instruction_list.next(curr_item);
|
||||
}
|
||||
return item_block;
|
||||
}
|
||||
|
||||
// Returns whether any instruction in 'block' is blacklisted or
|
||||
// non-rematerializable.
|
||||
bool AnyBlacklistedOrNonRematerializable(
|
||||
const std::vector<Item*>& block,
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
|
||||
for (auto* item : block) {
|
||||
if (item->blacklisted) {
|
||||
return true;
|
||||
}
|
||||
if (!CanBeRematerialized(item->instruction, rematerializable_map)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::pair<std::vector<Item*>, RematStrategy>
|
||||
MemoryUsageTracker::PickRematerializationCandidates(
|
||||
const InstructionList& instruction_list, int64 memory_limit_bytes,
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
|
||||
Item* best_item = nullptr;
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
|
||||
int min_block_size, int max_block_size) {
|
||||
std::vector<Item*> best_items;
|
||||
int64 best_cost = 0;
|
||||
RematStrategy best_strategy;
|
||||
|
||||
VLOG(5) << "Picking candidate";
|
||||
VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", "
|
||||
<< max_block_size << "]";
|
||||
|
||||
// TODO(b/35244891): This is currently quadratic in the number of HLO
|
||||
// instructions.
|
||||
for (auto* item = instruction_list.first(); item != nullptr;
|
||||
item = instruction_list.next(item)) {
|
||||
if (!item->placed) {
|
||||
// Only iterate up to the currently placed instruction.
|
||||
// We are trying to reduce memory usage at the placed
|
||||
// instruction so rematerializing later values is of no benefit.
|
||||
for (auto* start_item = instruction_list.first(); start_item != nullptr;
|
||||
start_item = instruction_list.next(start_item)) {
|
||||
std::vector<Item*> block =
|
||||
GetInitialBlock(instruction_list, *this, start_item, min_block_size);
|
||||
if (block.size() < min_block_size) {
|
||||
// There are no more blocks of size at least min_block_size with placed
|
||||
// instructions.
|
||||
break;
|
||||
}
|
||||
HloInstruction* candidate = item->instruction;
|
||||
VLOG(5) << "considering rematerialization candidate " << candidate->name();
|
||||
|
||||
if (item->blacklisted) {
|
||||
// Skip instructions on the blacklist to avoid infinite loops of
|
||||
// rematerializing the same instruction(s) repeatedly.
|
||||
VLOG(5) << "candidate " << candidate->name()
|
||||
<< " is excluded from rematerialization";
|
||||
// If any item in the starting block is blacklisted or non-rematerializable, then
|
||||
// break and move on to next start_item (we can actually move to the last
|
||||
// invalid item in this block, but let's ignore that optimization for now).
|
||||
if (AnyBlacklistedOrNonRematerializable(block, rematerializable_map)) {
|
||||
continue;
|
||||
}
|
||||
if (!CanBeRematerialized(candidate, remat_able)) {
|
||||
VLOG(5) << "candidate " << candidate->name()
|
||||
<< " not viable: is not rematerializable";
|
||||
while (block.size() <= max_block_size) {
|
||||
// block size = 1 is treated separately since we consider compression in
|
||||
// this case only.
|
||||
if (block.size() == 1) {
|
||||
auto* item = block[0];
|
||||
auto* candidate = item->instruction;
|
||||
if (item->buffers_output.size() == 1 &&
|
||||
(mode_ ==
|
||||
HloRematerialization::RematerializationMode::kCompressOnly ||
|
||||
mode_ == HloRematerialization::RematerializationMode::
|
||||
kRecomputeAndCompress)) {
|
||||
// Only consider compressing single output instruction.
|
||||
const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item->buffers_output.size() == 1 &&
|
||||
(mode_ == HloRematerialization::RematerializationMode::kCompressOnly ||
|
||||
mode_ == HloRematerialization::RematerializationMode::
|
||||
kRecomputeAndCompress)) {
|
||||
// Only consider compressing single output instruction.
|
||||
const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
|
||||
|
||||
if (item->placed && item != in_progress_item_ &&
|
||||
!output_buffer.live_out) {
|
||||
const Shape& original_shape = item->instruction->shape();
|
||||
if (original_shape.IsArray()) {
|
||||
Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
|
||||
const int64 memory_reduced =
|
||||
MemoryReducedIfCompressed(item, compact_shape);
|
||||
if (memory_reduced > 0) {
|
||||
const int64 cost = memory_limit_bytes / memory_reduced;
|
||||
if (best_item == nullptr || cost < best_cost) {
|
||||
VLOG(3) << "candidate " << candidate->name() << "("
|
||||
<< candidate->ToShortString() << ")"
|
||||
<< " now best when compressed into "
|
||||
<< compact_shape.ToString(true);
|
||||
RematStrategy strategy;
|
||||
strategy.kind = RematStrategy::kCompress;
|
||||
best_strategy = strategy;
|
||||
best_strategy.compact_shape = compact_shape;
|
||||
best_item = item;
|
||||
best_cost = cost;
|
||||
if (item->placed && item != in_progress_item_ &&
|
||||
!output_buffer.live_out) {
|
||||
const Shape& original_shape = item->instruction->shape();
|
||||
if (original_shape.IsArray()) {
|
||||
Shape compact_shape =
|
||||
GetCompactShape(item->instruction).ValueOrDie();
|
||||
const int64 memory_reduced =
|
||||
MemoryReducedIfCompressed(item, compact_shape);
|
||||
if (memory_reduced > 0) {
|
||||
const int64 cost = memory_limit_bytes / memory_reduced;
|
||||
if (best_items.empty() || cost < best_cost) {
|
||||
VLOG(3) << "candidate " << candidate->name() << "("
|
||||
<< candidate->ToShortString() << ")"
|
||||
<< " now best when compressed into "
|
||||
<< compact_shape.ToString(true);
|
||||
RematStrategy strategy;
|
||||
strategy.kind = RematStrategy::kCompress;
|
||||
best_strategy = strategy;
|
||||
best_strategy.compact_shape = compact_shape;
|
||||
best_items = block;
|
||||
best_cost = cost;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If any of the candidate's control successor has been placed, we need
|
||||
// to skip this candidate. Otherwise we will violate control dependency.
|
||||
bool control_successor_placed = std::any_of(
|
||||
candidate->control_successors().begin(),
|
||||
candidate->control_successors().end(),
|
||||
[this](const HloInstruction* inst) { return IsPlaced(inst); });
|
||||
|
||||
if (control_successor_placed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not consider recomputation in compress-only mode.
|
||||
if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64 memory_reduced = MemoryReducedIfRematerialized(item);
|
||||
|
||||
if (memory_reduced > 0) {
|
||||
const int cost =
|
||||
RematerializationCost(candidate, memory_reduced, memory_limit_bytes);
|
||||
|
||||
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
|
||||
<< memory_reduced << ", cost per byte " << cost;
|
||||
|
||||
if (best_item == nullptr || cost < best_cost) {
|
||||
VLOG(5) << "candidate " << candidate->name() << " now best";
|
||||
best_strategy.kind = RematStrategy::kRecompute;
|
||||
best_item = item;
|
||||
best_cost = cost;
|
||||
// Do not consider recomputation in compress-only mode.
|
||||
if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
|
||||
// break out of this loop. Move on to the next start_item.
|
||||
break;
|
||||
}
|
||||
// If any of the candidate's control successor has been placed, we need
|
||||
// to skip this candidate. Otherwise we will violate control dependency.
|
||||
bool control_successor_placed = false;
|
||||
for (auto* item : block) {
|
||||
HloInstruction* candidate = item->instruction;
|
||||
if (std::any_of(candidate->control_successors().begin(),
|
||||
candidate->control_successors().end(),
|
||||
[this](const HloInstruction* inst) {
|
||||
return IsPlaced(inst);
|
||||
})) {
|
||||
control_successor_placed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (control_successor_placed) {
|
||||
// break out of this loop. Move on to the next start_item.
|
||||
break;
|
||||
}
|
||||
const int64 memory_reduced = MemoryReducedIfRematerialized(block);
|
||||
|
||||
if (memory_reduced > 0) {
|
||||
const int cost =
|
||||
RematerializationCost(block, memory_reduced, memory_limit_bytes);
|
||||
|
||||
VLOG(5) << "Candidate block of size " << block.size()
|
||||
<< " starting from " << block[0]->instruction->name()
|
||||
<< ", memory reduced " << memory_reduced << ", cost per byte "
|
||||
<< cost;
|
||||
|
||||
if (best_items.empty() || cost < best_cost) {
|
||||
VLOG(5) << "Candidate block of size " << block.size()
|
||||
<< " starting from " << block[0]->instruction->name()
|
||||
<< " now best";
|
||||
best_strategy.kind = RematStrategy::kRecompute;
|
||||
best_items = block;
|
||||
best_cost = cost;
|
||||
}
|
||||
}
|
||||
|
||||
// Time to update the block to include the next instruction.
|
||||
auto* last_item = block[block.size() - 1];
|
||||
auto* next_item = instruction_list.next(last_item);
|
||||
if (next_item == nullptr || next_item->blacklisted ||
|
||||
!next_item->placed || next_item == in_progress_item_ ||
|
||||
!CanBeRematerialized(next_item->instruction, rematerializable_map)) {
|
||||
break;
|
||||
}
|
||||
block.push_back(next_item);
|
||||
}
|
||||
}
|
||||
return {best_item, best_strategy};
|
||||
return {best_items, best_strategy};
|
||||
}
|
||||
|
||||
bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
|
||||
@ -1402,6 +1471,60 @@ StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
|
||||
return 2;
|
||||
}
|
||||
|
||||
// A simple struct to encapsulate the number of instructions added during
// rematerialization. Both counters start at zero so that a default-constructed
// value never carries indeterminate data.
struct InstructionsAdded {
  // Total count of instructions rematerialized.
  int remat_count = 0;
  // Total count of instructions rematerialized minus number of original
  // instructions that are now dead.
  int net_instructions_added = 0;
};
|
||||
|
||||
// Rematerializes the best block of instructions of size between min_block_size
|
||||
// and max_block_size (both inclusive) if at least one candidate block of
|
||||
// instructions can be found. Returns number of instructions rematerialized.
|
||||
StatusOr<InstructionsAdded> RematerializeBestBlock(
|
||||
int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker,
|
||||
InstructionList* instruction_list, int64 memory_limit_bytes,
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
|
||||
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions) {
|
||||
CHECK(min_block_size > 0) << "Negative block size.";
|
||||
|
||||
std::vector<Item*> best_items;
|
||||
RematStrategy best_strategy;
|
||||
std::tie(best_items, best_strategy) =
|
||||
memory_tracker->PickRematerializationCandidates(
|
||||
*instruction_list, memory_limit_bytes, rematerializable_map,
|
||||
min_block_size, max_block_size);
|
||||
InstructionsAdded num_instructions_added;
|
||||
num_instructions_added.remat_count = best_items.size();
|
||||
if (best_items.empty()) {
|
||||
num_instructions_added.net_instructions_added = 0;
|
||||
return num_instructions_added;
|
||||
}
|
||||
|
||||
if (best_strategy.kind == RematStrategy::kCompress) {
|
||||
CHECK(best_items.size() == 1)
|
||||
<< "More than one instruction compressed simultaneously.";
|
||||
HloInstruction* best = best_items[0]->instruction;
|
||||
VLOG(1) << "Compressing instruction " << best->name() << " (saving "
|
||||
<< HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
|
||||
best_items[0], best_strategy.compact_shape))
|
||||
<< ")";
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
num_instructions_added.net_instructions_added,
|
||||
CompressInstruction(memory_tracker, best_items[0],
|
||||
best_strategy.compact_shape, instruction_list));
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
num_instructions_added.net_instructions_added,
|
||||
RematerializeInstructions(memory_tracker, &best_items,
|
||||
remat_move_instructions, instruction_list));
|
||||
}
|
||||
return num_instructions_added;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<int64> HloRematerialization::ComputePeakMemory(
|
||||
@ -1465,7 +1588,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
|
||||
|
||||
// The map from instructions to their rematerializable status.
|
||||
absl::flat_hash_map<const HloInstruction*, bool> remat_able;
|
||||
absl::flat_hash_map<const HloInstruction*, bool> rematerializable_map;
|
||||
|
||||
// The peak memory of the computation at any point in the instruction
|
||||
// sequence.
|
||||
@ -1496,6 +1619,11 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
<< "/" << instruction_list.size() << "]";
|
||||
instruction_index++;
|
||||
|
||||
// Initialize both min_block_size and max_block_size to 1 so that only
|
||||
// single instruction rematerialization is considered first.
|
||||
int min_block_size = 1;
|
||||
int max_block_size = 1;
|
||||
|
||||
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
|
||||
VLOG(2) << "Over memory limit at instruction " << instruction->name()
|
||||
<< ", using "
|
||||
@ -1503,53 +1631,31 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
callee_usage)
|
||||
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
|
||||
|
||||
Item* best_item;
|
||||
RematStrategy best_strategy;
|
||||
std::tie(best_item, best_strategy) =
|
||||
memory_tracker.PickRematerializationCandidate(
|
||||
instruction_list, memory_limit_bytes, &remat_able);
|
||||
|
||||
if (best_item == nullptr) {
|
||||
VLOG(3) << "Unable to find rematerialization candidate at program "
|
||||
"point "
|
||||
<< instruction->name() << ". Memory usage = "
|
||||
<< HumanReadableNumBytes(memory_tracker.memory_usage() +
|
||||
callee_usage);
|
||||
break;
|
||||
}
|
||||
|
||||
HloInstruction* best = best_item->instruction;
|
||||
changed = true;
|
||||
remat_count++;
|
||||
|
||||
int64 num_instructions_added = 0;
|
||||
if (best_strategy.kind == RematStrategy::kCompress) {
|
||||
VLOG(1) << "Compressing instruction " << best->name() << " (saving "
|
||||
<< HumanReadableNumBytes(
|
||||
memory_tracker.MemoryReducedIfCompressed(
|
||||
best_item, best_strategy.compact_shape))
|
||||
<< ")";
|
||||
|
||||
TF_ASSIGN_OR_RETURN(num_instructions_added,
|
||||
CompressInstruction(&memory_tracker, best_item,
|
||||
best_strategy.compact_shape,
|
||||
&instruction_list));
|
||||
} else {
|
||||
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
|
||||
<< HumanReadableNumBytes(
|
||||
memory_tracker.MemoryReducedIfRematerialized(best_item))
|
||||
<< ")";
|
||||
|
||||
std::vector<Item*> best_items{best_item};
|
||||
TF_ASSIGN_OR_RETURN(num_instructions_added,
|
||||
RematerializeInstructions(
|
||||
&memory_tracker, &best_items,
|
||||
&remat_move_instructions, &instruction_list));
|
||||
}
|
||||
net_instructions_added += num_instructions_added;
|
||||
TF_ASSIGN_OR_RETURN(InstructionsAdded instructions_added,
|
||||
RematerializeBestBlock(
|
||||
min_block_size, max_block_size, &memory_tracker,
|
||||
&instruction_list, memory_limit_bytes,
|
||||
&rematerializable_map, &remat_move_instructions));
|
||||
net_instructions_added += instructions_added.net_instructions_added;
|
||||
remat_count += instructions_added.remat_count;
|
||||
|
||||
VLOG(1) << "memory_usage after rematerialization = "
|
||||
<< HumanReadableNumBytes(memory_tracker.memory_usage());
|
||||
if (instructions_added.remat_count == 0) {
|
||||
// Unable to find a block to rematerialize.
|
||||
// Consider doubling the block size.
|
||||
min_block_size = max_block_size + 1;
|
||||
max_block_size = 2 * max_block_size;
|
||||
} else {
|
||||
// Found a valid block. Reset to start looking for single instructions
|
||||
// again.
|
||||
changed = true;
|
||||
min_block_size = 1;
|
||||
max_block_size = 1;
|
||||
}
|
||||
if (max_block_size > block_size_limit_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const CallSite* callsite = call_graph_node.GetCallSite(instruction);
|
||||
|
@ -83,12 +83,14 @@ class HloRematerialization : public HloModulePass {
|
||||
explicit HloRematerialization(
|
||||
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
|
||||
RematerializationSizes* sizes, RematerializationPass pass_location,
|
||||
int block_size_limit,
|
||||
CompactShapeFunction compact_shape_function = nullptr,
|
||||
RematerializationMode mode = RematerializationMode::kRecomputeAndCompress)
|
||||
: size_function_(size_function),
|
||||
memory_limit_bytes_(memory_limit_bytes),
|
||||
sizes_(sizes),
|
||||
pass_location_(pass_location),
|
||||
block_size_limit_(block_size_limit),
|
||||
compact_shape_function_(compact_shape_function == nullptr
|
||||
? DefaultCompactShapeFunction
|
||||
: std::move(compact_shape_function)),
|
||||
@ -144,6 +146,10 @@ class HloRematerialization : public HloModulePass {
|
||||
// multi-output fusion.
|
||||
RematerializationPass pass_location_;
|
||||
|
||||
// Maximum number of consecutive instructions to consider for
|
||||
// rematerialization.
|
||||
int block_size_limit_;
|
||||
|
||||
// Converts a shape into compact form, returns the same shape if a shape is
|
||||
// already considered compact.
|
||||
const CompactShapeFunction compact_shape_function_;
|
||||
|
@ -50,7 +50,8 @@ class HloRematerializationTest : public RematerializationTestBase {
|
||||
HloRematerialization remat(
|
||||
ByteSizeOf, memory_limit_bytes,
|
||||
/*sizes=*/nullptr,
|
||||
HloRematerialization::RematerializationPass::kPreFusion);
|
||||
HloRematerialization::RematerializationPass::kPreFusion,
|
||||
/*block_size_limit=*/1);
|
||||
return remat.Run(module);
|
||||
}
|
||||
};
|
||||
@ -582,7 +583,7 @@ class CompressingRematerializationTest : public RematerializationTestBase {
|
||||
ShapeSizePadMinorTo64, memory_limit_bytes,
|
||||
/*sizes=*/nullptr,
|
||||
HloRematerialization::RematerializationPass::kPreFusion,
|
||||
ChooseCompactLayoutForShape);
|
||||
/*block_size_limit=*/1, ChooseCompactLayoutForShape);
|
||||
return remat.Run(module);
|
||||
}
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user