Express remat: Only consider remat with nodes that create big buffers.

Use a skip list in remat pass to only go through certain skip nodes
when considering remat candidates beyond certain size.

Note that this algorithm is still quadratic, but now quadratic to the number of big buffers.

PiperOrigin-RevId: 333395527
Change-Id: Ib229e0499c10e66751131c384da401f9b63f1b70
This commit is contained in:
Yunxing Dai 2020-09-23 16:20:00 -07:00 committed by TensorFlower Gardener
parent f3462c311d
commit c1b165b68c
3 changed files with 229 additions and 51 deletions

View File

@ -135,12 +135,17 @@ struct Item {
// The buffers used by this instruction.
BufferIdList buffers_used;
bool is_skip_node = false;
private:
friend class InstructionList;
// Items are arranged in a doubly linked list.
Item* next;
Item* prev;
Item* next = nullptr;
Item* prev = nullptr;
Item* prev_skip_node = nullptr;
Item* next_skip_node = nullptr;
// List is ordered by position, which can however be duplicated as
// new instructions are inserted. See InsertBeforeInstructions
@ -152,11 +157,23 @@ using ItemList = absl::InlinedVector<Item*, 3>;
// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
//
// This is a skip list structure that has two lanes: express lane and slow lane.
// All nodes are presented on the slow lane but a node can be promoted into
// express lane for fast iteration.
//
// In the following case, node 2 and node + 1 are connected via an express lane.
// +--------------------------+----------->: Express lane
// | |
// node1<-> node 2 <-> .. <-> node n <-> node n+1 <->...: Slow lane
//
class InstructionList {
public:
explicit InstructionList(const HloInstructionSequence& order) {
int64 position = 0;
Item* last = nullptr;
last_skip_node_ = nullptr;
first_skip_node_ = nullptr;
for (HloInstruction* inst : order.instructions()) {
// Add a new item to the linked list.
Item* item = new Item;
@ -198,6 +215,9 @@ class InstructionList {
Item* first() const { return first_; }
Item* next(Item* item) const { return item->next; }
Item* first_skip_node() const { return first_skip_node_; }
Item* next_skip_node(Item* item) const { return item->next_skip_node; }
// Creates an Item for the given instruction, but doesn't add it to the list.
// (Use InsertBeforeInstructions to add the Item to the list.)
Item* CreateItem(HloInstruction* inst) {
@ -266,6 +286,27 @@ class InstructionList {
return InsertBefore(to_insert, min_position_item);
}
// Scan the list and promote nodes to express lane if should_promote(Item)
// returns true;
void PromoteNodesToSkip(std::function<bool(Item*)> should_promote) {
int64 count = 0;
for (auto* item = first(); item != nullptr; item = next(item)) {
if (should_promote(item)) {
count += 1;
if (first_skip_node_ == nullptr) {
first_skip_node_ = item;
}
item->is_skip_node = true;
item->prev_skip_node = last_skip_node_;
if (last_skip_node_ != nullptr) {
last_skip_node_->next_skip_node = item;
}
last_skip_node_ = item;
}
}
VLOG(1) << " Rematerialization has " << count << " items in express lane";
}
void InsertAfterInstructions(Item* to_insert,
absl::Span<Item* const> after_instructions) {
VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
@ -301,6 +342,44 @@ class InstructionList {
void InsertBefore(Item* item, Item* before) {
VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
<< before->instruction->name();
// Always place new nodes on express lane for the ease of implementation.
item->is_skip_node = true;
// Find the next express node starting from 'before'. Set up the node's
// express pointers.
Item* cursor = before;
while (cursor != nullptr && !cursor->is_skip_node) {
cursor = cursor->next;
}
CHECK(cursor == nullptr || cursor->is_skip_node);
if (cursor == nullptr) {
//
// last_skip_node_<---+ : express lane
// |
// ...<->`item`<-> .. <-> `cursor`(null) : slow lane
//
// Reached the end. Set the prev_express to last_skip_node, and reset
// last_skip.
item->prev_skip_node = last_skip_node_;
item->next_skip_node = nullptr;
last_skip_node_ = item;
} else {
//
// <-+------------+----------------+---------> : express lane
// | | |
// prev_express..<->`item`<-> .. <-> `cursor` <-> ...: slow lane
//
// Reached the next skip node, sets up express pointers accordingly.
CHECK(cursor->is_skip_node);
item->prev_skip_node = cursor->prev_skip_node;
if (item->prev_skip_node != nullptr) {
item->prev_skip_node->next_skip_node = item;
}
item->next_skip_node = cursor;
cursor->prev_skip_node = item;
}
if (first_skip_node_ == cursor) {
first_skip_node_ = item;
}
// Insert new item into linked list.
item->prev = before->prev;
item->next = before;
@ -319,6 +398,12 @@ class InstructionList {
Item* first_;
// First skip node of this list.
Item* first_skip_node_;
// Last skip node of this list.
Item* last_skip_node_;
// Item for each instruction.
absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
};
@ -460,6 +545,15 @@ class MemoryUsageTracker {
// values.
int64 memory_usage() const { return memory_usage_; }
//
int64 AllocatedSize(Item* item) const {
int64 size = 0;
for (auto buffer_id : item->buffers_defined) {
size += AllocatedSize(buffer_id);
}
return size;
}
// Check invariants of the data structure. This is expensive to call.
bool Check() const;
@ -652,7 +746,6 @@ MemoryUsageTracker::MemoryUsageTracker(
.CreateFlattenedSet();
absl::flat_hash_map<const LogicalBuffer*, BufferId>
logical_buffer_to_buffer_id;
for (auto* item = instruction_list_.first(); item != nullptr;
item = instruction_list_.next(item)) {
const HloInstruction* const instruction = item->instruction;
@ -1186,8 +1279,9 @@ MemoryUsageTracker::PickRematerializationCandidates(
VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", "
<< max_block_size << "]";
for (auto* start_item = instruction_list.first(); start_item != nullptr;
start_item = instruction_list.next(start_item)) {
for (auto* start_item = instruction_list.first_skip_node();
start_item != nullptr;
start_item = instruction_list.next_skip_node(start_item)) {
std::vector<Item*> block =
GetInitialBlock(instruction_list, *this, start_item, min_block_size);
if (block.size() < min_block_size) {
@ -1566,7 +1660,7 @@ StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
StatusOr<bool> HloRematerialization::RematerializeComputation(
HloComputation* computation, HloSchedule* schedule,
int64 memory_limit_bytes) {
int64 memory_limit_bytes, int64 min_remat_size) {
VLOG(1) << "Rematerializing computation " << computation->name()
<< " with limit " << HumanReadableNumBytes(memory_limit_bytes);
VLOG(1) << "peak memory usage is "
@ -1577,6 +1671,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
MemoryUsageTracker memory_tracker(
computation, size_function_, compact_shape_function_,
*points_to_analysis_, instruction_list, mode_);
instruction_list.PromoteNodesToSkip([&](Item* item) {
return memory_tracker.AllocatedSize(item) >= min_remat_size;
});
bool changed = false;
// If the rematerialization makes the source instruction dead, then the
@ -1622,43 +1720,46 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// single instruction rematerialization is considered first.
int min_block_size = 1;
int max_block_size = 1;
// Only trigger rematerialization when the memory usage changes.
if (memory_tracker.AllocatedSize(item) + callee_usage > 0) {
while (memory_tracker.memory_usage() + callee_usage >
memory_limit_bytes) {
VLOG(2) << "Over memory limit at instruction " << instruction->name()
<< ", using "
<< HumanReadableNumBytes(memory_tracker.memory_usage() +
callee_usage)
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
VLOG(2) << "Over memory limit at instruction " << instruction->name()
<< ", using "
<< HumanReadableNumBytes(memory_tracker.memory_usage() +
callee_usage)
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
TF_ASSIGN_OR_RETURN(
InstructionsAdded instructions_added,
RematerializeBestBlock(min_block_size, max_block_size,
&memory_tracker, &instruction_list,
memory_limit_bytes, &rematerializable_map,
&remat_move_instructions));
net_instructions_added += instructions_added.net_instructions_added;
remat_count += instructions_added.remat_count;
TF_ASSIGN_OR_RETURN(InstructionsAdded instructions_added,
RematerializeBestBlock(
min_block_size, max_block_size, &memory_tracker,
&instruction_list, memory_limit_bytes,
&rematerializable_map, &remat_move_instructions));
net_instructions_added += instructions_added.net_instructions_added;
remat_count += instructions_added.remat_count;
VLOG(1) << "memory_usage after rematerialization = "
<< HumanReadableNumBytes(memory_tracker.memory_usage());
if (instructions_added.remat_count == 0) {
// Unable to find a block to rematerialize.
// Consider doubling the block size.
min_block_size = max_block_size + 1;
max_block_size = 2 * max_block_size;
} else {
// Found a valid block. Reset to start looking for single instructions
// again.
max_rematerialized_block_size_ =
std::max(max_rematerialized_block_size_, max_block_size);
changed = true;
min_block_size = 1;
max_block_size = 1;
}
if (max_block_size > block_size_limit_) {
break;
VLOG(1) << "memory_usage after rematerialization = "
<< HumanReadableNumBytes(memory_tracker.memory_usage());
if (instructions_added.remat_count == 0) {
// Unable to find a block to rematerialize.
// Consider doubling the block size.
min_block_size = max_block_size + 1;
max_block_size = 2 * max_block_size;
} else {
// Found a valid block. Reset to start looking for single instructions
// again.
max_rematerialized_block_size_ =
std::max(max_rematerialized_block_size_, max_block_size);
changed = true;
min_block_size = 1;
max_block_size = 1;
}
if (max_block_size > block_size_limit_) {
break;
}
}
}
const CallSite* callsite = call_graph_node.GetCallSite(instruction);
if (callsite != nullptr &&
callsite->context() == CallContext::kSequential &&
@ -1683,10 +1784,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
TF_ASSIGN_OR_RETURN(
bool subcomputation_changed,
RematerializeComputation(called_computation, schedule,
subcomputation_memory_limit_bytes));
subcomputation_memory_limit_bytes,
min_remat_size));
changed |= subcomputation_changed;
}
}
TF_ASSIGN_OR_RETURN(callee_usage,
CalledComputationsMemoryUsage(instruction));
}
@ -1786,14 +1889,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
module_output_size;
VLOG(1) << "Peak memory usage of module (before): "
<< HumanReadableNumBytes(before_peak_memory);
// Subcomputations called by the entry computation will also be
// rematerialized.
TF_ASSIGN_OR_RETURN(
bool changed,
RematerializeComputation(module->entry_computation(), &module->schedule(),
adjusted_memory_limit_bytes));
adjusted_memory_limit_bytes, min_remat_size_));
// Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction.
@ -1838,7 +1939,6 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
HumanReadableNumBytes(current_peak_memory), current_peak_memory);
}
return changed;
}

View File

@ -85,7 +85,8 @@ class HloRematerialization : public HloModulePass {
RematerializationSizes* sizes, RematerializationPass pass_location,
int block_size_limit,
CompactShapeFunction compact_shape_function = nullptr,
RematerializationMode mode = RematerializationMode::kRecomputeAndCompress)
RematerializationMode mode = RematerializationMode::kRecomputeAndCompress,
int64 min_remat_size = 0)
: size_function_(size_function),
memory_limit_bytes_(memory_limit_bytes),
sizes_(sizes),
@ -94,7 +95,8 @@ class HloRematerialization : public HloModulePass {
compact_shape_function_(compact_shape_function == nullptr
? DefaultCompactShapeFunction
: std::move(compact_shape_function)),
mode_(mode) {}
mode_(mode),
min_remat_size_(min_remat_size) {}
~HloRematerialization() override = default;
absl::string_view name() const override { return "rematerialization"; }
@ -114,7 +116,8 @@ class HloRematerialization : public HloModulePass {
// and inserted into 'order'.
virtual StatusOr<bool> RematerializeComputation(HloComputation* computation,
HloSchedule* schedule,
int64 memory_limit_bytes);
int64 memory_limit_bytes,
int64 min_remat_size);
// Computes and returns the peak memory used by the given computation. The
// peak memory is the maximum total size of all live HLO instruction values at
@ -185,6 +188,8 @@ class HloRematerialization : public HloModulePass {
int max_rematerialized_block_size_ = 0;
RematerializationMode mode_;
int64 min_remat_size_;
};
} // namespace xla

View File

@ -41,7 +41,8 @@ using ::testing::_;
class HloRematerializationTest : public RematerializationTestBase {
protected:
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
HloModule* module) {
HloModule* module,
int64 min_remat_size = 0) {
TF_EXPECT_OK(verifier().Run(module).status());
HloMemoryScheduler scheduler(
[](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
@ -51,7 +52,9 @@ class HloRematerializationTest : public RematerializationTestBase {
ByteSizeOf, memory_limit_bytes,
/*sizes=*/nullptr,
HloRematerialization::RematerializationPass::kPreFusion,
/*block_size_limit=*/1);
/*block_size_limit=*/1, nullptr,
HloRematerialization::RematerializationMode::kRecomputeAndCompress,
min_remat_size);
return remat.Run(module);
}
};
@ -96,6 +99,26 @@ TEST_F(HloRematerializationTest, SingleComputation) {
remat_bcast);
}
// Test rematerialization of a single computation that contains nodes that
// doesn't contain node worth using remat.
TEST_F(HloRematerializationTest, SingleComputationNoWorthRemat) {
auto module = CreateNewVerifiedModule();
HloComputation* computation =
module->AddEntryComputation(MakeRematerializableComputation());
// Find and save the original broadcast instruction which should be
// rematerialized.
const HloInstruction* slice = computation->root_instruction();
ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
// Set the minimum remat size to 14KiB, meaning no nodes should be remat.
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/14 * 1024, module.get(),
/*min_remat_size=*/14 * 1024));
EXPECT_FALSE(changed);
}
// Test rematerialization of a single computation produced by
// MakeRematerializableComputation but with a sufficiently high memory limit
// such that no instructions are rematerialized.
@ -577,17 +600,67 @@ class CompressingRematerializationTest : public RematerializationTestBase {
}
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
HloModule* module) {
HloModule* module,
int64 min_remat_size = 0) {
TF_EXPECT_OK(verifier().Run(module).status());
HloRematerialization remat(
ShapeSizePadMinorTo64, memory_limit_bytes,
/*sizes=*/nullptr,
HloRematerialization::RematerializationPass::kPreFusion,
/*block_size_limit=*/1, ChooseCompactLayoutForShape);
/*block_size_limit=*/1, ChooseCompactLayoutForShape,
HloRematerialization::RematerializationMode::kCompressOnly,
min_remat_size);
return remat.Run(module);
}
};
// Test rematerialization only remats big buffer that pass certain limits.
TEST_F(CompressingRematerializationTest, OnlyRematBigBuffer) {
const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_float {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%constant = f32[] constant(0)
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%broadcast.1 = f32[10,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
ROOT %add.2 = f32[] add(f32[] %add, f32[] %reduce.2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
// Only rematerialize buffers which have shaep f32[64, 2]. Buffers with shape
// f32[10, 2] are ignored.
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/30 * 1024,
module.get(), 10 * 1024));
EXPECT_TRUE(changed);
HloInstruction* broadcast =
module->entry_computation()->GetInstructionWithName("broadcast.0");
HloInstruction* broadcast_2 =
module->entry_computation()->GetInstructionWithName("broadcast.1");
HloInstruction* reduce =
module->entry_computation()->GetInstructionWithName("reduce.1");
HloInstruction* reduce_2 =
module->entry_computation()->GetInstructionWithName("reduce.2");
EXPECT_THAT(reduce,
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
EXPECT_THAT(reduce_2, op::Reduce(broadcast_2, op::Constant()));
}
// Test rematerialization of a single instruction.
TEST_F(CompressingRematerializationTest, SingleRemat) {
const string& hlo_string = R"(