Express remat: Only consider remat with nodes that create big buffers.

Use a skip list in remat pass to only go through certain skip nodes
when considering remat candidates beyond certain size.

Note that this algorithm is still quadratic, but it is now quadratic in the number of big buffers.

PiperOrigin-RevId: 333395527
Change-Id: Ib229e0499c10e66751131c384da401f9b63f1b70
This commit is contained in:
Yunxing Dai 2020-09-23 16:20:00 -07:00 committed by TensorFlower Gardener
parent f3462c311d
commit c1b165b68c
3 changed files with 229 additions and 51 deletions

View File

@ -135,12 +135,17 @@ struct Item {
// The buffers used by this instruction. // The buffers used by this instruction.
BufferIdList buffers_used; BufferIdList buffers_used;
bool is_skip_node = false;
private: private:
friend class InstructionList; friend class InstructionList;
// Items are arranged in a doubly linked list. // Items are arranged in a doubly linked list.
Item* next; Item* next = nullptr;
Item* prev; Item* prev = nullptr;
Item* prev_skip_node = nullptr;
Item* next_skip_node = nullptr;
// List is ordered by position, which can however be duplicated as // List is ordered by position, which can however be duplicated as
// new instructions are inserted. See InsertBeforeInstructions // new instructions are inserted. See InsertBeforeInstructions
@ -152,11 +157,23 @@ using ItemList = absl::InlinedVector<Item*, 3>;
// Class which maintains an ordered list of instructions with fast insertion // Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements. // before arbitrary elements.
//
// This is a skip list structure that has two lanes: express lane and slow lane.
// All nodes are present on the slow lane, but a node can be promoted to the
// express lane for fast iteration.
//
// In the following case, node 2 and node n+1 are connected via an express lane.
// +--------------------------+----------->: Express lane
// | |
// node1<-> node 2 <-> .. <-> node n <-> node n+1 <->...: Slow lane
//
class InstructionList { class InstructionList {
public: public:
explicit InstructionList(const HloInstructionSequence& order) { explicit InstructionList(const HloInstructionSequence& order) {
int64 position = 0; int64 position = 0;
Item* last = nullptr; Item* last = nullptr;
last_skip_node_ = nullptr;
first_skip_node_ = nullptr;
for (HloInstruction* inst : order.instructions()) { for (HloInstruction* inst : order.instructions()) {
// Add a new item to the linked list. // Add a new item to the linked list.
Item* item = new Item; Item* item = new Item;
@ -198,6 +215,9 @@ class InstructionList {
Item* first() const { return first_; } Item* first() const { return first_; }
Item* next(Item* item) const { return item->next; } Item* next(Item* item) const { return item->next; }
Item* first_skip_node() const { return first_skip_node_; }
Item* next_skip_node(Item* item) const { return item->next_skip_node; }
// Creates an Item for the given instruction, but doesn't add it to the list. // Creates an Item for the given instruction, but doesn't add it to the list.
// (Use InsertBeforeInstructions to add the Item to the list.) // (Use InsertBeforeInstructions to add the Item to the list.)
Item* CreateItem(HloInstruction* inst) { Item* CreateItem(HloInstruction* inst) {
@ -266,6 +286,27 @@ class InstructionList {
return InsertBefore(to_insert, min_position_item); return InsertBefore(to_insert, min_position_item);
} }
// Scan the list and promote nodes to the express lane if should_promote(Item)
// returns true.
void PromoteNodesToSkip(std::function<bool(Item*)> should_promote) {
int64 count = 0;
for (auto* item = first(); item != nullptr; item = next(item)) {
if (should_promote(item)) {
count += 1;
if (first_skip_node_ == nullptr) {
first_skip_node_ = item;
}
item->is_skip_node = true;
item->prev_skip_node = last_skip_node_;
if (last_skip_node_ != nullptr) {
last_skip_node_->next_skip_node = item;
}
last_skip_node_ = item;
}
}
VLOG(1) << " Rematerialization has " << count << " items in express lane";
}
void InsertAfterInstructions(Item* to_insert, void InsertAfterInstructions(Item* to_insert,
absl::Span<Item* const> after_instructions) { absl::Span<Item* const> after_instructions) {
VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name() VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
@ -301,6 +342,44 @@ class InstructionList {
void InsertBefore(Item* item, Item* before) { void InsertBefore(Item* item, Item* before) {
VLOG(3) << "InsertBefore: " << item->instruction->name() << " before " VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
<< before->instruction->name(); << before->instruction->name();
// Always place new nodes on express lane for the ease of implementation.
item->is_skip_node = true;
// Find the next express node starting from 'before'. Set up the node's
// express pointers.
Item* cursor = before;
while (cursor != nullptr && !cursor->is_skip_node) {
cursor = cursor->next;
}
CHECK(cursor == nullptr || cursor->is_skip_node);
if (cursor == nullptr) {
//
// last_skip_node_<---+ : express lane
// |
// ...<->`item`<-> .. <-> `cursor`(null) : slow lane
//
// Reached the end. Set the item's prev_skip_node to last_skip_node_ and
// make the item the new last_skip_node_.
item->prev_skip_node = last_skip_node_;
item->next_skip_node = nullptr;
last_skip_node_ = item;
} else {
//
// <-+------------+----------------+---------> : express lane
// | | |
// prev_express..<->`item`<-> .. <-> `cursor` <-> ...: slow lane
//
// Reached the next skip node, sets up express pointers accordingly.
CHECK(cursor->is_skip_node);
item->prev_skip_node = cursor->prev_skip_node;
if (item->prev_skip_node != nullptr) {
item->prev_skip_node->next_skip_node = item;
}
item->next_skip_node = cursor;
cursor->prev_skip_node = item;
}
if (first_skip_node_ == cursor) {
first_skip_node_ = item;
}
// Insert new item into linked list. // Insert new item into linked list.
item->prev = before->prev; item->prev = before->prev;
item->next = before; item->next = before;
@ -319,6 +398,12 @@ class InstructionList {
Item* first_; Item* first_;
// First skip node of this list.
Item* first_skip_node_;
// Last skip node of this list.
Item* last_skip_node_;
// Item for each instruction. // Item for each instruction.
absl::flat_hash_map<const HloInstruction*, Item*> item_map_; absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
}; };
@ -460,6 +545,15 @@ class MemoryUsageTracker {
// values. // values.
int64 memory_usage() const { return memory_usage_; } int64 memory_usage() const { return memory_usage_; }
//
int64 AllocatedSize(Item* item) const {
int64 size = 0;
for (auto buffer_id : item->buffers_defined) {
size += AllocatedSize(buffer_id);
}
return size;
}
// Check invariants of the data structure. This is expensive to call. // Check invariants of the data structure. This is expensive to call.
bool Check() const; bool Check() const;
@ -652,7 +746,6 @@ MemoryUsageTracker::MemoryUsageTracker(
.CreateFlattenedSet(); .CreateFlattenedSet();
absl::flat_hash_map<const LogicalBuffer*, BufferId> absl::flat_hash_map<const LogicalBuffer*, BufferId>
logical_buffer_to_buffer_id; logical_buffer_to_buffer_id;
for (auto* item = instruction_list_.first(); item != nullptr; for (auto* item = instruction_list_.first(); item != nullptr;
item = instruction_list_.next(item)) { item = instruction_list_.next(item)) {
const HloInstruction* const instruction = item->instruction; const HloInstruction* const instruction = item->instruction;
@ -1186,8 +1279,9 @@ MemoryUsageTracker::PickRematerializationCandidates(
VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", " VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", "
<< max_block_size << "]"; << max_block_size << "]";
for (auto* start_item = instruction_list.first(); start_item != nullptr; for (auto* start_item = instruction_list.first_skip_node();
start_item = instruction_list.next(start_item)) { start_item != nullptr;
start_item = instruction_list.next_skip_node(start_item)) {
std::vector<Item*> block = std::vector<Item*> block =
GetInitialBlock(instruction_list, *this, start_item, min_block_size); GetInitialBlock(instruction_list, *this, start_item, min_block_size);
if (block.size() < min_block_size) { if (block.size() < min_block_size) {
@ -1566,7 +1660,7 @@ StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
StatusOr<bool> HloRematerialization::RematerializeComputation( StatusOr<bool> HloRematerialization::RematerializeComputation(
HloComputation* computation, HloSchedule* schedule, HloComputation* computation, HloSchedule* schedule,
int64 memory_limit_bytes) { int64 memory_limit_bytes, int64 min_remat_size) {
VLOG(1) << "Rematerializing computation " << computation->name() VLOG(1) << "Rematerializing computation " << computation->name()
<< " with limit " << HumanReadableNumBytes(memory_limit_bytes); << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
VLOG(1) << "peak memory usage is " VLOG(1) << "peak memory usage is "
@ -1577,6 +1671,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
MemoryUsageTracker memory_tracker( MemoryUsageTracker memory_tracker(
computation, size_function_, compact_shape_function_, computation, size_function_, compact_shape_function_,
*points_to_analysis_, instruction_list, mode_); *points_to_analysis_, instruction_list, mode_);
instruction_list.PromoteNodesToSkip([&](Item* item) {
return memory_tracker.AllocatedSize(item) >= min_remat_size;
});
bool changed = false; bool changed = false;
// If the rematerialization makes the source instruction dead, then the // If the rematerialization makes the source instruction dead, then the
@ -1622,19 +1720,22 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// single instruction rematerialization is considered first. // single instruction rematerialization is considered first.
int min_block_size = 1; int min_block_size = 1;
int max_block_size = 1; int max_block_size = 1;
// Only trigger rematerialization when the memory usage changes.
while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { if (memory_tracker.AllocatedSize(item) + callee_usage > 0) {
while (memory_tracker.memory_usage() + callee_usage >
memory_limit_bytes) {
VLOG(2) << "Over memory limit at instruction " << instruction->name() VLOG(2) << "Over memory limit at instruction " << instruction->name()
<< ", using " << ", using "
<< HumanReadableNumBytes(memory_tracker.memory_usage() + << HumanReadableNumBytes(memory_tracker.memory_usage() +
callee_usage) callee_usage)
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes); << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
TF_ASSIGN_OR_RETURN(InstructionsAdded instructions_added, TF_ASSIGN_OR_RETURN(
RematerializeBestBlock( InstructionsAdded instructions_added,
min_block_size, max_block_size, &memory_tracker, RematerializeBestBlock(min_block_size, max_block_size,
&instruction_list, memory_limit_bytes, &memory_tracker, &instruction_list,
&rematerializable_map, &remat_move_instructions)); memory_limit_bytes, &rematerializable_map,
&remat_move_instructions));
net_instructions_added += instructions_added.net_instructions_added; net_instructions_added += instructions_added.net_instructions_added;
remat_count += instructions_added.remat_count; remat_count += instructions_added.remat_count;
@ -1658,7 +1759,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
break; break;
} }
} }
}
const CallSite* callsite = call_graph_node.GetCallSite(instruction); const CallSite* callsite = call_graph_node.GetCallSite(instruction);
if (callsite != nullptr && if (callsite != nullptr &&
callsite->context() == CallContext::kSequential && callsite->context() == CallContext::kSequential &&
@ -1683,10 +1784,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
bool subcomputation_changed, bool subcomputation_changed,
RematerializeComputation(called_computation, schedule, RematerializeComputation(called_computation, schedule,
subcomputation_memory_limit_bytes)); subcomputation_memory_limit_bytes,
min_remat_size));
changed |= subcomputation_changed; changed |= subcomputation_changed;
} }
} }
TF_ASSIGN_OR_RETURN(callee_usage, TF_ASSIGN_OR_RETURN(callee_usage,
CalledComputationsMemoryUsage(instruction)); CalledComputationsMemoryUsage(instruction));
} }
@ -1786,14 +1889,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
module_output_size; module_output_size;
VLOG(1) << "Peak memory usage of module (before): " VLOG(1) << "Peak memory usage of module (before): "
<< HumanReadableNumBytes(before_peak_memory); << HumanReadableNumBytes(before_peak_memory);
// Subcomputations called by the entry computation will also be // Subcomputations called by the entry computation will also be
// rematerialized. // rematerialized.
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
bool changed, bool changed,
RematerializeComputation(module->entry_computation(), &module->schedule(), RematerializeComputation(module->entry_computation(), &module->schedule(),
adjusted_memory_limit_bytes)); adjusted_memory_limit_bytes, min_remat_size_));
// Rematerialization can introduce dead code. This occurs if all uses of an // Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction. // instruction are replaced with rematerializations of the instruction.
@ -1838,7 +1939,6 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
HumanReadableNumBytes(current_peak_memory), current_peak_memory); HumanReadableNumBytes(current_peak_memory), current_peak_memory);
} }
return changed; return changed;
} }

View File

@ -85,7 +85,8 @@ class HloRematerialization : public HloModulePass {
RematerializationSizes* sizes, RematerializationPass pass_location, RematerializationSizes* sizes, RematerializationPass pass_location,
int block_size_limit, int block_size_limit,
CompactShapeFunction compact_shape_function = nullptr, CompactShapeFunction compact_shape_function = nullptr,
RematerializationMode mode = RematerializationMode::kRecomputeAndCompress) RematerializationMode mode = RematerializationMode::kRecomputeAndCompress,
int64 min_remat_size = 0)
: size_function_(size_function), : size_function_(size_function),
memory_limit_bytes_(memory_limit_bytes), memory_limit_bytes_(memory_limit_bytes),
sizes_(sizes), sizes_(sizes),
@ -94,7 +95,8 @@ class HloRematerialization : public HloModulePass {
compact_shape_function_(compact_shape_function == nullptr compact_shape_function_(compact_shape_function == nullptr
? DefaultCompactShapeFunction ? DefaultCompactShapeFunction
: std::move(compact_shape_function)), : std::move(compact_shape_function)),
mode_(mode) {} mode_(mode),
min_remat_size_(min_remat_size) {}
~HloRematerialization() override = default; ~HloRematerialization() override = default;
absl::string_view name() const override { return "rematerialization"; } absl::string_view name() const override { return "rematerialization"; }
@ -114,7 +116,8 @@ class HloRematerialization : public HloModulePass {
// and inserted into 'order'. // and inserted into 'order'.
virtual StatusOr<bool> RematerializeComputation(HloComputation* computation, virtual StatusOr<bool> RematerializeComputation(HloComputation* computation,
HloSchedule* schedule, HloSchedule* schedule,
int64 memory_limit_bytes); int64 memory_limit_bytes,
int64 min_remat_size);
// Computes and returns the peak memory used by the given computation. The // Computes and returns the peak memory used by the given computation. The
// peak memory is the maximum total size of all live HLO instruction values at // peak memory is the maximum total size of all live HLO instruction values at
@ -185,6 +188,8 @@ class HloRematerialization : public HloModulePass {
int max_rematerialized_block_size_ = 0; int max_rematerialized_block_size_ = 0;
RematerializationMode mode_; RematerializationMode mode_;
int64 min_remat_size_;
}; };
} // namespace xla } // namespace xla

View File

@ -41,7 +41,8 @@ using ::testing::_;
class HloRematerializationTest : public RematerializationTestBase { class HloRematerializationTest : public RematerializationTestBase {
protected: protected:
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes, StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
HloModule* module) { HloModule* module,
int64 min_remat_size = 0) {
TF_EXPECT_OK(verifier().Run(module).status()); TF_EXPECT_OK(verifier().Run(module).status());
HloMemoryScheduler scheduler( HloMemoryScheduler scheduler(
[](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
@ -51,7 +52,9 @@ class HloRematerializationTest : public RematerializationTestBase {
ByteSizeOf, memory_limit_bytes, ByteSizeOf, memory_limit_bytes,
/*sizes=*/nullptr, /*sizes=*/nullptr,
HloRematerialization::RematerializationPass::kPreFusion, HloRematerialization::RematerializationPass::kPreFusion,
/*block_size_limit=*/1); /*block_size_limit=*/1, nullptr,
HloRematerialization::RematerializationMode::kRecomputeAndCompress,
min_remat_size);
return remat.Run(module); return remat.Run(module);
} }
}; };
@ -96,6 +99,26 @@ TEST_F(HloRematerializationTest, SingleComputation) {
remat_bcast); remat_bcast);
} }
// Test rematerialization of a single computation that contains no nodes
// worth rematerializing.
TEST_F(HloRematerializationTest, SingleComputationNoWorthRemat) {
auto module = CreateNewVerifiedModule();
HloComputation* computation =
module->AddEntryComputation(MakeRematerializableComputation());
// Locate the broadcast instruction; with the high min_remat_size below, it
// should not be rematerialized.
const HloInstruction* slice = computation->root_instruction();
ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
// Set the minimum remat size to 14KiB, meaning no nodes should be rematerialized.
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
/*memory_limit_bytes=*/14 * 1024, module.get(),
/*min_remat_size=*/14 * 1024));
EXPECT_FALSE(changed);
}
// Test rematerialization of a single computation produced by // Test rematerialization of a single computation produced by
// MakeRematerializableComputation but with a sufficiently high memory limit // MakeRematerializableComputation but with a sufficiently high memory limit
// such that no instructions are rematerialized. // such that no instructions are rematerialized.
@ -577,17 +600,67 @@ class CompressingRematerializationTest : public RematerializationTestBase {
} }
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes, StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
HloModule* module) { HloModule* module,
int64 min_remat_size = 0) {
TF_EXPECT_OK(verifier().Run(module).status()); TF_EXPECT_OK(verifier().Run(module).status());
HloRematerialization remat( HloRematerialization remat(
ShapeSizePadMinorTo64, memory_limit_bytes, ShapeSizePadMinorTo64, memory_limit_bytes,
/*sizes=*/nullptr, /*sizes=*/nullptr,
HloRematerialization::RematerializationPass::kPreFusion, HloRematerialization::RematerializationPass::kPreFusion,
/*block_size_limit=*/1, ChooseCompactLayoutForShape); /*block_size_limit=*/1, ChooseCompactLayoutForShape,
HloRematerialization::RematerializationMode::kCompressOnly,
min_remat_size);
return remat.Run(module); return remat.Run(module);
} }
}; };
// Test rematerialization only remats big buffer that pass certain limits.
TEST_F(CompressingRematerializationTest, OnlyRematBigBuffer) {
const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
%add_float {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %entry {
%param.0 = f32[] parameter(0)
%constant = f32[] constant(0)
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%broadcast.1 = f32[10,2]{1,0} broadcast(f32[] %param.0), dimensions={}
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%reduce.2 = f32[] reduce(f32[10,2]{1,0} %broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
ROOT %add.2 = f32[] add(f32[] %add, f32[] %reduce.2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
// Only rematerialize buffers which have shape f32[64, 2]. Buffers with shape
// f32[10, 2] are ignored.
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/30 * 1024,
module.get(), 10 * 1024));
EXPECT_TRUE(changed);
HloInstruction* broadcast =
module->entry_computation()->GetInstructionWithName("broadcast.0");
HloInstruction* broadcast_2 =
module->entry_computation()->GetInstructionWithName("broadcast.1");
HloInstruction* reduce =
module->entry_computation()->GetInstructionWithName("reduce.1");
HloInstruction* reduce_2 =
module->entry_computation()->GetInstructionWithName("reduce.2");
EXPECT_THAT(reduce,
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
EXPECT_THAT(reduce_2, op::Reduce(broadcast_2, op::Constant()));
}
// Test rematerialization of a single instruction. // Test rematerialization of a single instruction.
TEST_F(CompressingRematerializationTest, SingleRemat) { TEST_F(CompressingRematerializationTest, SingleRemat) {
const string& hlo_string = R"( const string& hlo_string = R"(