[TF:XLA] Remove a const_cast from HLO rematerialization.

This required making HloSchedule and related classes a sequence of HloInstruction* (instead of const HloInstruction*)

Also, use HloInstructionSequence in a few more places (hlo_rematerialization.cc)

No functional change.

PiperOrigin-RevId: 221868639
This commit is contained in:
A. Unique TensorFlower 2018-11-16 16:06:20 -08:00 committed by TensorFlower Gardener
parent 8930d5aff5
commit 2121365e74
23 changed files with 142 additions and 151 deletions

View File

@ -137,8 +137,7 @@ class BufferAssignmentTest : public HloTestBase {
} }
std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence( std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
HloModule* module, HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
absl::Span<const HloInstruction* const> instruction_sequence,
int64 alignment = 1) { int64 alignment = 1) {
HloSchedule schedule(module); HloSchedule schedule(module);
schedule.set_sequence(module->entry_computation(), instruction_sequence); schedule.set_sequence(module->entry_computation(), instruction_sequence);
@ -1853,7 +1852,7 @@ class WhileBufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module, std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) { int64 alignment = 1) {
HloSchedule schedule = HloSchedule schedule =
ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run( return BufferAssigner::Run(
module, absl::make_unique<SequentialHloOrdering>(schedule), module, absl::make_unique<SequentialHloOrdering>(schedule),
ByteSizeOf, ByteSizeOf,
@ -2162,7 +2161,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
// nodes are traversed during BufferAssignment. // nodes are traversed during BufferAssignment.
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) { ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), return ShapeUtil::ByteSizeOf(buffer.shape(),
/*pointer_size=*/sizeof(void*)); /*pointer_size=*/sizeof(void*));
})); }));
@ -2391,15 +2390,16 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
RunCopyInsertion(module.get()); RunCopyInsertion(module.get());
HloSchedule schedule = HloSchedule schedule =
ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();
// To trigger b/38494731, we want a specific Hlo schedule for the // To trigger b/38494731, we want a specific Hlo schedule for the
// root computation, so we overwrite that entry with a manually // root computation, so we overwrite that entry with a manually
// crafted sequence. // crafted sequence.
schedule.set_sequence(module->entry_computation(), schedule.set_sequence(
{input1, weights1, one, output1, while1->operand(0), module->entry_computation(),
while1, input0, weights0, zero, output0, {input1, weights1, one, output1, while1->mutable_operand(0), while1,
while0->operand(0), while0, gte0, gte1, root_add}); input0, weights0, zero, output0, while0->mutable_operand(0), while0,
gte0, gte1, root_add});
// If this ASSERT fails, we constructed a bogus sequence above and this test // If this ASSERT fails, we constructed a bogus sequence above and this test
// itself is buggy. // itself is buggy.

View File

@ -587,9 +587,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// Select an order for emitting the HLO instructions for each // Select an order for emitting the HLO instructions for each
// computation. Using this sequence enables tighter buffer liveness analysis // computation. Using this sequence enables tighter buffer liveness analysis
// and reduced memory usage (as compared to using DependencyHloOrdering). // and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(HloSchedule schedule,
HloSchedule schedule, ScheduleModule(module.get(), BufferSizeBytesFunction(),
ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); DFSMemoryScheduler));
// Run buffer allocation on the HLO graph. // Run buffer allocation on the HLO graph.
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
@ -779,7 +779,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
XLA_VLOG_LINES(2, module->ToString()); XLA_VLOG_LINES(2, module->ToString());
TF_ASSIGN_OR_RETURN(HloSchedule schedule, TF_ASSIGN_OR_RETURN(HloSchedule schedule,
ScheduleModule(*module, BufferSizeBytesFunction())); ScheduleModule(module, BufferSizeBytesFunction()));
// Run buffer analysis on the HLO graph. This analysis figures out which // Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation. // temporary buffers are required to run the computation.

View File

@ -111,7 +111,7 @@ IrEmitter::IrEmitter(
StatusOr<llvm::Function*> IrEmitter::EmitComputation( StatusOr<llvm::Function*> IrEmitter::EmitComputation(
HloComputation* computation, const string& function_name_prefix, HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation, bool is_top_level_computation,
const std::vector<const HloInstruction*>* instruction_order) { const std::vector<HloInstruction*>* instruction_order) {
string function_name = name_uniquer_.GetUniqueName(function_name_prefix); string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
<< "]; ordered? " << (instruction_order != nullptr); << "]; ordered? " << (instruction_order != nullptr);

View File

@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
StatusOr<llvm::Function*> EmitComputation( StatusOr<llvm::Function*> EmitComputation(
HloComputation* computation, const string& function_name_prefix, HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation, bool is_top_level_computation,
const std::vector<const HloInstruction*>* instruction_order); const std::vector<HloInstruction*>* instruction_order);
llvm::IRBuilder<>* b() { return &b_; } llvm::IRBuilder<>* b() { return &b_; }

View File

@ -37,7 +37,7 @@ class GpuHloOrdering : public PredecessorHloOrdering {
public: public:
GpuHloOrdering(const HloModule* module, GpuHloOrdering(const HloModule* module,
const StreamAssignment& stream_assignment, const StreamAssignment& stream_assignment,
const std::vector<const HloInstruction*>& thunk_launch_order); const std::vector<HloInstruction*>& thunk_launch_order);
~GpuHloOrdering() override = default; ~GpuHloOrdering() override = default;
// Only the entry computation can possibly be sequentially ordered, and only // Only the entry computation can possibly be sequentially ordered, and only
@ -56,7 +56,7 @@ class GpuHloOrdering : public PredecessorHloOrdering {
GpuHloOrdering::GpuHloOrdering( GpuHloOrdering::GpuHloOrdering(
const HloModule* module, const StreamAssignment& stream_assignment, const HloModule* module, const StreamAssignment& stream_assignment,
const std::vector<const HloInstruction*>& thunk_launch_order) const std::vector<HloInstruction*>& thunk_launch_order)
: PredecessorHloOrdering(module) { : PredecessorHloOrdering(module) {
// The entry computation has a total order when there's only one stream. // The entry computation has a total order when there's only one stream.
if (stream_assignment.StreamCount() == 1) { if (stream_assignment.StreamCount() == 1) {
@ -150,7 +150,7 @@ GpuHloOrdering::GpuHloOrdering(
// However, if the total order is A,B,D,C,E, then C and E can run // However, if the total order is A,B,D,C,E, then C and E can run
// concurrently. // concurrently.
void BFSLaunchOrder(const HloComputation* computation, void BFSLaunchOrder(const HloComputation* computation,
std::vector<const HloInstruction*>* launch_order) { std::vector<HloInstruction*>* launch_order) {
// This topological sort uses two data structures: // This topological sort uses two data structures:
// 1. `incoming_edge_count` which keeps track of the number of incoming // 1. `incoming_edge_count` which keeps track of the number of incoming
// edges to each HLO; // edges to each HLO;
@ -158,9 +158,9 @@ void BFSLaunchOrder(const HloComputation* computation,
// //
// The sorting algorithm repeatedly pops the top from the queue and deletes // The sorting algorithm repeatedly pops the top from the queue and deletes
// that HLO from the graph, making more HLOs incoming-edge free. // that HLO from the graph, making more HLOs incoming-edge free.
std::deque<const HloInstruction*> queue; std::deque<HloInstruction*> queue;
std::unordered_map<const HloInstruction*, int64> incoming_edge_count; std::unordered_map<const HloInstruction*, int64> incoming_edge_count;
for (const auto& hlo : computation->instructions()) { for (auto* hlo : computation->instructions()) {
if (hlo->operand_count() == 0) { if (hlo->operand_count() == 0) {
queue.push_back(hlo); queue.push_back(hlo);
} else { } else {
@ -172,10 +172,10 @@ void BFSLaunchOrder(const HloComputation* computation,
} }
while (!queue.empty()) { while (!queue.empty()) {
const HloInstruction* x = queue.front(); HloInstruction* x = queue.front();
queue.pop_front(); queue.pop_front();
launch_order->push_back(x); launch_order->push_back(x);
for (const HloInstruction* y : x->users()) { for (HloInstruction* y : x->users()) {
--incoming_edge_count[y]; --incoming_edge_count[y];
if (incoming_edge_count[y] == 0) { if (incoming_edge_count[y] == 0) {
queue.push_back(y); queue.push_back(y);
@ -195,14 +195,14 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
std::unique_ptr<GpuHloSchedule> schedule(new GpuHloSchedule); std::unique_ptr<GpuHloSchedule> schedule(new GpuHloSchedule);
// Initialize thunk_launch_order_, the total order of thunk launches. // Initialize thunk_launch_order_, the total order of thunk launches.
const HloComputation* entry_computation = module.entry_computation(); HloComputation* entry_computation = module.entry_computation();
if (stream_assignment.StreamCount() == 1) { if (stream_assignment.StreamCount() == 1) {
// All kernels are launched on a single stream, so there's no loss of // All kernels are launched on a single stream, so there's no loss of
// concurrency by optimizing for minimal memory usage. // concurrency by optimizing for minimal memory usage.
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
HloInstructionSequence sequence, HloInstructionSequence sequence,
ScheduleComputation( ScheduleComputation(
*entry_computation, [pointer_size](const BufferValue& buffer) { entry_computation, [pointer_size](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
})); }));
schedule->thunk_launch_order_ = sequence.instructions(); schedule->thunk_launch_order_ = sequence.instructions();

View File

@ -46,7 +46,7 @@ class GpuHloSchedule {
// Returns the total order of thunk launches, represented in terms of HLO // Returns the total order of thunk launches, represented in terms of HLO
// instructions. // instructions.
const std::vector<const HloInstruction*>& ThunkLaunchOrder() const { const std::vector<HloInstruction*>& ThunkLaunchOrder() const {
return thunk_launch_order_; return thunk_launch_order_;
} }
@ -60,7 +60,7 @@ class GpuHloSchedule {
private: private:
GpuHloSchedule(); GpuHloSchedule();
std::vector<const HloInstruction*> thunk_launch_order_; std::vector<HloInstruction*> thunk_launch_order_;
std::unique_ptr<HloOrdering> hlo_ordering_; std::unique_ptr<HloOrdering> hlo_ordering_;
}; };

View File

@ -33,7 +33,7 @@ namespace gpu {
class GpuHloScheduleTest : public HloTestBase { class GpuHloScheduleTest : public HloTestBase {
protected: protected:
using HloVec = std::vector<const HloInstruction*>; using HloVec = std::vector<HloInstruction*>;
// Pre-canned shapes. // Pre-canned shapes.
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});

View File

@ -45,7 +45,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands(
ThunkSchedule::ThunkSchedule( ThunkSchedule::ThunkSchedule(
std::unique_ptr<ThunkSequence> thunks, std::unique_ptr<ThunkSequence> thunks,
std::unique_ptr<StreamAssignment> stream_assignment, std::unique_ptr<StreamAssignment> stream_assignment,
const std::vector<const HloInstruction*>& hlo_total_order) const std::vector<HloInstruction*>& hlo_total_order)
: thunks_(std::move(thunks)), : thunks_(std::move(thunks)),
stream_assignment_(std::move(stream_assignment)) { stream_assignment_(std::move(stream_assignment)) {
std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk; std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk;
@ -53,7 +53,7 @@ ThunkSchedule::ThunkSchedule(
InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
} }
for (const HloInstruction* hlo : hlo_total_order) { for (HloInstruction* hlo : hlo_total_order) {
if (hlo_to_thunk.count(hlo)) { if (hlo_to_thunk.count(hlo)) {
thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo));
} }

View File

@ -46,7 +46,7 @@ class ThunkSchedule {
public: public:
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks, ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,
std::unique_ptr<StreamAssignment> stream_assignment, std::unique_ptr<StreamAssignment> stream_assignment,
const std::vector<const HloInstruction*>& hlo_total_order); const std::vector<HloInstruction*>& hlo_total_order);
// Returns the total order of executing all the thunks. // Returns the total order of executing all the thunks.
const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; } const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }

View File

@ -258,7 +258,7 @@ class HeapSimulatorTracker {
// Constructor for testing a single entry computation. // Constructor for testing a single entry computation.
HeapSimulatorTracker( HeapSimulatorTracker(
const string& name, std::unique_ptr<HloComputation> computation, const string& name, std::unique_ptr<HloComputation> computation,
const std::vector<const HloInstruction*>& instruction_sequence) { const std::vector<HloInstruction*>& instruction_sequence) {
HloModuleConfig config; HloModuleConfig config;
module_ = absl::make_unique<HloModule>(name, config); module_ = absl::make_unique<HloModule>(name, config);
module_->AddEntryComputation(std::move(computation)); module_->AddEntryComputation(std::move(computation));
@ -286,7 +286,7 @@ class HeapSimulatorTracker {
// Similar to the single entry computation constructor above, but runs the // Similar to the single entry computation constructor above, but runs the
// simulation over the entire module. // simulation over the entire module.
void RunWholeModule( void RunWholeModule(
const std::vector<const HloInstruction*>& full_module_sequence) { const std::vector<HloInstruction*>& full_module_sequence) {
points_to_analysis_ = points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
@ -294,7 +294,7 @@ class HeapSimulatorTracker {
HloSchedule schedule(module_.get()); HloSchedule schedule(module_.get());
absl::flat_hash_map<const HloInstruction*, int> reverse_position; absl::flat_hash_map<const HloInstruction*, int> reverse_position;
for (int i = 0; i < full_module_sequence.size(); ++i) { for (int i = 0; i < full_module_sequence.size(); ++i) {
const HloInstruction* instruction = full_module_sequence[i]; HloInstruction* instruction = full_module_sequence[i];
schedule.GetOrCreateSequence(instruction->parent()) schedule.GetOrCreateSequence(instruction->parent())
.push_back(instruction); .push_back(instruction);
reverse_position[instruction] = full_module_sequence.size() - i; reverse_position[instruction] = full_module_sequence.size() - i;

View File

@ -795,7 +795,7 @@ Status HloComputation::AcceptWithOperandOrder(
template <typename HloInstructionPtr> template <typename HloInstructionPtr>
Status HloComputation::AcceptOrdered( Status HloComputation::AcceptOrdered(
DfsHloVisitorBase<HloInstructionPtr>* visitor, DfsHloVisitorBase<HloInstructionPtr>* visitor,
const std::vector<const HloInstruction*>& order) const { const std::vector<HloInstruction*>& order) const {
VLOG(3) << "Accepting visitor with order."; VLOG(3) << "Accepting visitor with order.";
for (HloInstruction* root : CollectUnreachableRoots()) { for (HloInstruction* root : CollectUnreachableRoots()) {
TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
@ -825,9 +825,9 @@ Status HloComputation::AcceptOrdered(
// Explicit instantiations. // Explicit instantiations.
template Status HloComputation::AcceptOrdered( template Status HloComputation::AcceptOrdered(
DfsHloVisitor*, const std::vector<const HloInstruction*>&) const; DfsHloVisitor*, const std::vector<HloInstruction*>&) const;
template Status HloComputation::AcceptOrdered( template Status HloComputation::AcceptOrdered(
ConstDfsHloVisitor*, const std::vector<const HloInstruction*>&) const; ConstDfsHloVisitor*, const std::vector<HloInstruction*>&) const;
Status HloComputation::Accept( Status HloComputation::Accept(
const std::function<Status(HloInstruction*)>& visitor_func) { const std::function<Status(HloInstruction*)>& visitor_func) {

View File

@ -301,7 +301,7 @@ class HloComputation {
// be a topological sort of all instructions in the computation. // be a topological sort of all instructions in the computation.
template <typename HloInstructionPtr> template <typename HloInstructionPtr>
Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor, Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
const std::vector<const HloInstruction*>& order) const; const std::vector<HloInstruction*>& order) const;
// Same as Accept() above, but the visitor is given as a function. // Same as Accept() above, but the visitor is given as a function.
Status Accept(const std::function<Status(HloInstruction*)>& visitor_func); Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);

View File

@ -73,7 +73,7 @@ class ListScheduler {
// Construct and return a memory-minimizing sequence of HLO instructions // Construct and return a memory-minimizing sequence of HLO instructions
// containing the given HLO computation. // containing the given HLO computation.
static StatusOr<HloInstructionSequence> Run( static StatusOr<HloInstructionSequence> Run(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
@ -98,7 +98,7 @@ class ListScheduler {
// comparison operators. // comparison operators.
using Priority = std::pair<int64, int64>; using Priority = std::pair<int64, int64>;
ListScheduler(const HloComputation& computation, ListScheduler(HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
@ -111,7 +111,7 @@ class ListScheduler {
// instruction. An HLO instruction "uses" a LogicalBuffer if the // instruction. An HLO instruction "uses" a LogicalBuffer if the
// LogicalBuffer is in an operand of the instruction as indicated by // LogicalBuffer is in an operand of the instruction as indicated by
// points-to analysis. // points-to analysis.
for (auto* instruction : computation.instructions()) { for (auto* instruction : computation->instructions()) {
absl::flat_hash_set<const LogicalBuffer*> instr_uses; absl::flat_hash_set<const LogicalBuffer*> instr_uses;
for (auto* operand : instruction->operands()) { for (auto* operand : instruction->operands()) {
points_to_analysis.GetPointsToSet(operand).ForEachElement( points_to_analysis.GetPointsToSet(operand).ForEachElement(
@ -126,13 +126,13 @@ class ListScheduler {
// Create map containing the number of unscheduled uses (hlo instructions) // Create map containing the number of unscheduled uses (hlo instructions)
// of each logical buffer. // of each logical buffer.
for (auto* instruction : computation.instructions()) { for (auto* instruction : computation->instructions()) {
for (auto* buffer : for (auto* buffer :
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
unscheduled_use_count_[buffer] = 0; unscheduled_use_count_[buffer] = 0;
} }
} }
for (auto* instruction : computation.instructions()) { for (auto* instruction : computation->instructions()) {
for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
++unscheduled_use_count_[buffer]; ++unscheduled_use_count_[buffer];
} }
@ -141,7 +141,7 @@ class ListScheduler {
// Buffers live out of the computation have an implicit use at the end of // Buffers live out of the computation have an implicit use at the end of
// the computation. // the computation.
for (const LogicalBuffer* live_out_buffer : for (const LogicalBuffer* live_out_buffer :
points_to_analysis.GetPointsToSet(computation.root_instruction()) points_to_analysis.GetPointsToSet(computation->root_instruction())
.CreateFlattenedSet()) { .CreateFlattenedSet()) {
++unscheduled_use_count_[live_out_buffer]; ++unscheduled_use_count_[live_out_buffer];
} }
@ -157,7 +157,7 @@ class ListScheduler {
// HloInstruction, plus some cached metadata, saved for the purposes of making // HloInstruction, plus some cached metadata, saved for the purposes of making
// BytesFreedIfScheduled fast. // BytesFreedIfScheduled fast.
struct ReadyListEntry { struct ReadyListEntry {
const HloInstruction* instruction; HloInstruction* instruction;
// The total size of all buffers defined by this instruction. // The total size of all buffers defined by this instruction.
int64 bytes_defined; int64 bytes_defined;
@ -171,7 +171,7 @@ class ListScheduler {
}; };
// Creates a ReadyListEntry for the given instruction. // Creates a ReadyListEntry for the given instruction.
ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
ReadyListEntry entry; ReadyListEntry entry;
entry.instruction = instruction; entry.instruction = instruction;
@ -250,13 +250,13 @@ class ListScheduler {
// Populate the ready list with instructions which have no operands or // Populate the ready list with instructions which have no operands or
// control predecessors. // control predecessors.
absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count; absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
for (auto* instruction : computation_.instructions()) { for (auto* instruction : computation_->instructions()) {
// TODO(b/34466113): Replace this and above with successors() or // TODO(b/34466113): Replace this and above with successors() or
// predecessors() when these methods are added to HloInstruction. // predecessors() when these methods are added to HloInstruction.
for (const HloInstruction* user : instruction->users()) { for (HloInstruction* user : instruction->users()) {
unscheduled_pred_count[user]++; unscheduled_pred_count[user]++;
} }
for (const HloInstruction* succ : instruction->control_successors()) { for (HloInstruction* succ : instruction->control_successors()) {
unscheduled_pred_count[succ]++; unscheduled_pred_count[succ]++;
} }
} }
@ -275,7 +275,7 @@ class ListScheduler {
ready_instructions[inst] = it; ready_instructions[inst] = it;
}; };
for (auto* instruction : computation_.instructions()) { for (auto* instruction : computation_->instructions()) {
if (instruction->operands().empty() && if (instruction->operands().empty() &&
instruction->control_predecessors().empty()) { instruction->control_predecessors().empty()) {
add_to_ready_queue(instruction); add_to_ready_queue(instruction);
@ -287,7 +287,7 @@ class ListScheduler {
// schedule. // schedule.
auto best_it = ready_queue.end(); auto best_it = ready_queue.end();
--best_it; --best_it;
const HloInstruction* best = best_it->second.instruction; HloInstruction* best = best_it->second.instruction;
VLOG(2) << "Schedule instruction: " << best->ToShortString() VLOG(2) << "Schedule instruction: " << best->ToShortString()
<< " Bytes freed: " << best_it->first.first; << " Bytes freed: " << best_it->first.first;
ready_queue.erase(best_it); ready_queue.erase(best_it);
@ -348,13 +348,13 @@ class ListScheduler {
} }
} }
} }
CHECK_EQ(schedule.size(), computation_.instruction_count()); CHECK_EQ(schedule.size(), computation_->instruction_count());
CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
return schedule; return schedule;
} }
const HloComputation& computation_; HloComputation* computation_;
const TuplePointsToAnalysis& points_to_analysis_; const TuplePointsToAnalysis& points_to_analysis_;
const LogicalBuffer::SizeFunction& size_function_; const LogicalBuffer::SizeFunction& size_function_;
// Computations are analyzed in post-order. When scheduling an instruction // Computations are analyzed in post-order. When scheduling an instruction
@ -386,13 +386,13 @@ int64 SumLogicalBufferSizes(
} }
StatusOr<HloInstructionSequence> ScheduleComputationHelper( StatusOr<HloInstructionSequence> ScheduleComputationHelper(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm, const MemorySchedulerAlgorithm& algorithm,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) { memory_by_computation) {
VLOG(2) << "Computation: " << computation.name(); VLOG(2) << "Computation: " << computation->name();
if (algorithm) { if (algorithm) {
return algorithm(computation, points_to_analysis, size_function, return algorithm(computation, points_to_analysis, size_function,
memory_by_computation); memory_by_computation);
@ -404,17 +404,17 @@ StatusOr<HloInstructionSequence> ScheduleComputationHelper(
} // namespace } // namespace
StatusOr<HloInstructionSequence> DFSMemoryScheduler( StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) { memory_by_computation) {
// These variables are a hack to prevent overflows. // These variables are a hack to prevent overflows.
int64 cumulative_total_size = 0; int64 cumulative_total_size = 0;
int64 total_hlos = computation.parent()->instruction_count(); int64 total_hlos = computation->parent()->instruction_count();
absl::flat_hash_map<const HloInstruction*, int64> extra_users; absl::flat_hash_map<const HloInstruction*, int64> extra_users;
absl::flat_hash_map<const HloInstruction*, int64> total_sizes; absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
if (ListScheduler::IgnoreInstruction(*hlo)) { if (ListScheduler::IgnoreInstruction(*hlo)) {
extra_users[hlo] = 0; extra_users[hlo] = 0;
total_sizes[hlo] = 0; total_sizes[hlo] = 0;
@ -448,8 +448,8 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
extra_users[hlo] = std::min(extra_users[hlo], total_hlos); extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
} }
CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(extra_users.size(), computation->instruction_count());
CHECK_EQ(total_sizes.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation->instruction_count());
// Construct a total order based on DFS post-order, visiting operands in // Construct a total order based on DFS post-order, visiting operands in
// decreasing cumulative extra user order, and next by cumulative size, with a // decreasing cumulative extra user order, and next by cumulative size, with a
@ -459,7 +459,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
sequence.push_back(hlo); sequence.push_back(hlo);
return Status::OK(); return Status::OK();
}); });
TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
&visitor, [&extra_users, &total_sizes](const HloInstruction* a, &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
const HloInstruction* b) { const HloInstruction* b) {
if (extra_users[a] != extra_users[b]) { if (extra_users[a] != extra_users[b]) {
@ -470,12 +470,12 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
} }
return a->name() < b->name(); return a->name() < b->name();
})); }));
CHECK_EQ(sequence.size(), computation.instruction_count()); CHECK_EQ(sequence.size(), computation->instruction_count());
return sequence; return sequence;
} // namespace xla } // namespace xla
StatusOr<HloInstructionSequence> ListMemoryScheduler( StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
@ -485,16 +485,16 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
} }
StatusOr<HloInstructionSequence> PostOrderMemoryScheduler( StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) { memory_by_computation) {
return HloInstructionSequence(computation.MakeInstructionPostOrder()); return HloInstructionSequence(computation->MakeInstructionPostOrder());
} }
StatusOr<HloInstructionSequence> DefaultMemoryScheduler( StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
@ -513,7 +513,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
memory_by_computation)); memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 list_memory, TF_ASSIGN_OR_RETURN(const int64 list_memory,
HeapSimulator::MinimumMemoryForComputation( HeapSimulator::MinimumMemoryForComputation(
computation, list_sequence, points_to_analysis, *computation, list_sequence, points_to_analysis,
size_function, &memory_by_computation)); size_function, &memory_by_computation));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
@ -522,7 +522,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
size_function, memory_by_computation)); size_function, memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 dfs_memory, TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
HeapSimulator::MinimumMemoryForComputation( HeapSimulator::MinimumMemoryForComputation(
computation, dfs_sequence, points_to_analysis, *computation, dfs_sequence, points_to_analysis,
size_function, &memory_by_computation)); size_function, &memory_by_computation));
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
@ -532,7 +532,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
memory_by_computation)); memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 post_order_memory, TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
HeapSimulator::MinimumMemoryForComputation( HeapSimulator::MinimumMemoryForComputation(
computation, post_order_sequence, points_to_analysis, *computation, post_order_sequence, points_to_analysis,
size_function, &memory_by_computation)); size_function, &memory_by_computation));
VLOG(2) << "Min-memory post order sequence: " VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory); << HumanReadableNumBytes(post_order_memory);
@ -555,17 +555,17 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
} }
StatusOr<HloSchedule> ScheduleModule( StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function, HloModule* module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm) { const MemorySchedulerAlgorithm& algorithm) {
HloSchedule schedule(&module); HloSchedule schedule(module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module)); TuplePointsToAnalysis::Run(module));
absl::flat_hash_map<const HloComputation*, int64> memory_by_computation; absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
for (const auto* computation : module.MakeComputationPostOrder()) { for (auto* computation : module->MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) { if (!computation->IsFusionComputation()) {
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
ScheduleComputationHelper( ScheduleComputationHelper(
*computation, *points_to_analysis, size_function, computation, *points_to_analysis, size_function,
algorithm, memory_by_computation)); algorithm, memory_by_computation));
memory_by_computation[computation] = memory_by_computation[computation] =
HeapSimulator::MinimumMemoryForComputation( HeapSimulator::MinimumMemoryForComputation(
@ -583,11 +583,11 @@ StatusOr<HloSchedule> ScheduleModule(
} }
StatusOr<HloInstructionSequence> ScheduleComputation( StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation, HloComputation* computation,
const LogicalBuffer::SizeFunction& size_function) { const LogicalBuffer::SizeFunction& size_function) {
CHECK(!computation.IsFusionComputation()); CHECK(!computation->IsFusionComputation());
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent())); TuplePointsToAnalysis::Run(computation->parent()));
absl::flat_hash_map<const HloComputation*, int64> empty_map; absl::flat_hash_map<const HloComputation*, int64> empty_map;
return ScheduleComputationHelper(computation, *points_to_analysis, return ScheduleComputationHelper(computation, *points_to_analysis,
size_function, nullptr, empty_map); size_function, nullptr, empty_map);
@ -600,7 +600,7 @@ HloMemoryScheduler::HloMemoryScheduler(
StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) { StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(HloSchedule schedule, TF_ASSIGN_OR_RETURN(HloSchedule schedule,
ScheduleModule(*module, size_function_, algorithm_)); ScheduleModule(module, size_function_, algorithm_));
TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
return true; return true;
} }

View File

@ -36,14 +36,14 @@ namespace xla {
// that describes buffer aliasing, together with a target-specific size function // that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size. // that maps a tensor's logical size to its padded size.
typedef std::function<StatusOr<HloInstructionSequence>( typedef std::function<StatusOr<HloInstructionSequence>(
const HloComputation&, const TuplePointsToAnalysis&, HloComputation*, const TuplePointsToAnalysis&,
const LogicalBuffer::SizeFunction&, const LogicalBuffer::SizeFunction&,
const absl::flat_hash_map<const HloComputation*, int64>&)> const absl::flat_hash_map<const HloComputation*, int64>&)>
MemorySchedulerAlgorithm; MemorySchedulerAlgorithm;
// List scheduler // List scheduler
StatusOr<HloInstructionSequence> ListMemoryScheduler( StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
@ -51,7 +51,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
// DFS-order scheduler // DFS-order scheduler
StatusOr<HloInstructionSequence> DFSMemoryScheduler( StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
@ -59,7 +59,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
// Naive Post Order scheduler // Naive Post Order scheduler
StatusOr<HloInstructionSequence> PostOrderMemoryScheduler( StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
@ -69,7 +69,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
// and the DFS scheduler, and chooses whichever returns a lower min-memory, // and the DFS scheduler, and chooses whichever returns a lower min-memory,
// not accounting for fragmentation. // not accounting for fragmentation.
StatusOr<HloInstructionSequence> DefaultMemoryScheduler( StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation, HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis, const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function, const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>& const absl::flat_hash_map<const HloComputation*, int64>&
@ -79,13 +79,13 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
// the computation. size_function is the function returning the number of bytes // the computation. size_function is the function returning the number of bytes
// required for a LogicalBuffer. // required for a LogicalBuffer.
StatusOr<HloSchedule> ScheduleModule( StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function, HloModule* module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm = {}); const MemorySchedulerAlgorithm& algorithm = {});
// Computes the schedule for a single computation. // Computes the schedule for a single computation.
// Currently only used by the GPU backend. // Currently only used by the GPU backend.
StatusOr<HloInstructionSequence> ScheduleComputation( StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation, HloComputation* computation,
const LogicalBuffer::SizeFunction& size_function); const LogicalBuffer::SizeFunction& size_function);
// A pass which schedules the HLO instructions in a module. The HloModule's // A pass which schedules the HLO instructions in a module. The HloModule's

View File

@ -78,7 +78,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
TF_ASSERT_OK(module->schedule().Verify()); TF_ASSERT_OK(module->schedule().Verify());
// Verify that all instructions are in the sequence. // Verify that all instructions are in the sequence.
const std::vector<const HloInstruction*>& sequence = const std::vector<HloInstruction*>& sequence =
module->schedule().sequence(module->entry_computation()).instructions(); module->schedule().sequence(module->entry_computation()).instructions();
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
@ -124,9 +124,9 @@ ENTRY root {
}; };
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, size_fn, ListMemoryScheduler)); ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence. // Verify that all instructions are in the sequence.
const std::vector<const HloInstruction*>& sequence = const std::vector<HloInstruction*>& sequence =
schedule.sequence(module->entry_computation()).instructions(); schedule.sequence(module->entry_computation()).instructions();
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
@ -175,12 +175,13 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto module = CreateNewVerifiedModule(); auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build()); module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
ScheduleModule(*module, ScheduleModule(
[](const BufferValue& buffer) { module.get(),
return ShapeUtil::ByteSizeOf( [](const BufferValue& buffer) {
buffer.shape(), TUPLE_SIZE); return ShapeUtil::ByteSizeOf(buffer.shape(),
}, TUPLE_SIZE);
ListMemoryScheduler)); },
ListMemoryScheduler));
// Verify that all instructions are in the sequence. // Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(), EXPECT_EQ(module->entry_computation()->instruction_count(),
@ -225,12 +226,12 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
{tuple, mul, add}, HloInstruction::FusionKind::kLoop); {tuple, mul, add}, HloInstruction::FusionKind::kLoop);
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
ScheduleModule(*module, ScheduleModule(
[](const BufferValue& buffer) { module.get(),
return ShapeUtil::ByteSizeOf( [](const BufferValue& buffer) {
buffer.shape(), 2); return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
}, },
ListMemoryScheduler)); ListMemoryScheduler));
// Verify that all instructions are in the sequence. // Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(), EXPECT_EQ(module->entry_computation()->instruction_count(),
@ -284,7 +285,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
}; };
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, size_fn, ListMemoryScheduler)); ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence. // Verify that all instructions are in the sequence.
auto entry_computation = module->entry_computation(); auto entry_computation = module->entry_computation();
EXPECT_EQ(module->entry_computation()->instruction_count(), EXPECT_EQ(module->entry_computation()->instruction_count(),

View File

@ -104,11 +104,7 @@ class HloModule {
HloCloneContext* context = nullptr); HloCloneContext* context = nullptr);
// Return a pointer to the entry computation of the module. // Return a pointer to the entry computation of the module.
const HloComputation* entry_computation() const { HloComputation* entry_computation() const {
CHECK_NE(nullptr, entry_computation_);
return entry_computation_;
}
HloComputation* entry_computation() {
CHECK_NE(nullptr, entry_computation_); CHECK_NE(nullptr, entry_computation_);
return entry_computation_; return entry_computation_;
} }

View File

@ -356,8 +356,7 @@ void SequentialHloOrdering::Initialize() {
// Create a map from instruction to its order position. // Create a map from instruction to its order position.
TF_DCHECK_OK(schedule_.Verify()); TF_DCHECK_OK(schedule_.Verify());
for (const auto& computation_sequence : schedule_.sequences()) { for (const auto& computation_sequence : schedule_.sequences()) {
const std::vector<const HloInstruction*>& order = const auto& order = computation_sequence.second.instructions();
computation_sequence.second.instructions();
for (int i = 0; i < order.size(); ++i) { for (int i = 0; i < order.size(); ++i) {
InsertOrDie(&order_position_, order[i], i); InsertOrDie(&order_position_, order[i], i);
} }

View File

@ -47,11 +47,11 @@ const double kF16max = 65504;
// Creates and returns a schedule created using the order of the instructions in // Creates and returns a schedule created using the order of the instructions in
// the HloComputation::instructions() vectors in the module. // the HloComputation::instructions() vectors in the module.
HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
HloSchedule schedule(module); HloSchedule schedule(module);
for (const HloComputation* computation : module->computations()) { for (HloComputation* computation : module->computations()) {
if (!computation->IsFusionComputation()) { if (!computation->IsFusionComputation()) {
for (const HloInstruction* instruction : computation->instructions()) { for (HloInstruction* instruction : computation->instructions()) {
schedule.GetOrCreateSequence(computation).push_back(instruction); schedule.GetOrCreateSequence(computation).push_back(instruction);
} }
} }

View File

@ -130,10 +130,10 @@ using ItemList = absl::InlinedVector<Item*, 3>;
// before arbitrary elements. // before arbitrary elements.
class InstructionList { class InstructionList {
public: public:
explicit InstructionList(const std::vector<const HloInstruction*>& order) { explicit InstructionList(const HloInstructionSequence& order) {
int64 position = 0; int64 position = 0;
Item* last = nullptr; Item* last = nullptr;
for (const HloInstruction* inst : order) { for (HloInstruction* inst : order.instructions()) {
// Add a new item to the linked list. // Add a new item to the linked list.
Item* item = new Item; Item* item = new Item;
item->next = nullptr; item->next = nullptr;
@ -151,7 +151,7 @@ class InstructionList {
// to be monotonically increasing through the list, and so is still useful // to be monotonically increasing through the list, and so is still useful
// for quickly(-ish) determining the order of arbitrary instructions in // for quickly(-ish) determining the order of arbitrary instructions in
// the list. // the list.
item->instruction = const_cast<HloInstruction*>(inst); item->instruction = inst;
item->position = position; item->position = position;
position++; position++;
@ -927,7 +927,7 @@ Item* PickRematerializationCandidate(
StatusOr<int64> HloRematerialization::ComputePeakMemory( StatusOr<int64> HloRematerialization::ComputePeakMemory(
const HloComputation* computation, const HloComputation* computation,
const std::vector<const HloInstruction*>& order) const { const HloInstructionSequence& order) const {
InstructionList instruction_list(order); InstructionList instruction_list(order);
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
instruction_list); instruction_list);
@ -971,8 +971,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< HumanReadableNumBytes(computation_peak_memory_.at(computation)); << HumanReadableNumBytes(computation_peak_memory_.at(computation));
CHECK(!ContainsKey(rematerialized_computations_, computation)); CHECK(!ContainsKey(rematerialized_computations_, computation));
InstructionList instruction_list( InstructionList instruction_list(schedule->sequence(computation));
schedule->sequence(computation).instructions());
MemoryUsageTracker memory_tracker(computation, size_function_, MemoryUsageTracker memory_tracker(computation, size_function_,
*points_to_analysis_, instruction_list); *points_to_analysis_, instruction_list);
bool changed = false; bool changed = false;
@ -1184,7 +1183,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
sequence.clear(); sequence.clear();
for (auto* item = instruction_list.first(); item != nullptr; for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) { item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction; HloInstruction* instruction = item->instruction;
sequence.push_back(instruction); sequence.push_back(instruction);
} }
rematerialized_computations_.insert(computation); rematerialized_computations_.insert(computation);
@ -1235,10 +1234,8 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
if (node.context() == CallContext::kSequential) { if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()], computation_peak_memory_[node.computation()],
ComputePeakMemory(node.computation(), ComputePeakMemory(node.computation(), module->schedule().sequence(
module->schedule() node.computation())));
.sequence(node.computation())
.instructions()));
} }
return Status::OK(); return Status::OK();
}, },

View File

@ -87,9 +87,8 @@ class HloRematerialization : public HloModulePass {
// peak memory is the maximum total size of all live HLO instruction values at // peak memory is the maximum total size of all live HLO instruction values at
// any program point. 'order' is the order in which the HLO instructions will // any program point. 'order' is the order in which the HLO instructions will
// be emitted which is used to determine lifespans of HLO values. // be emitted which is used to determine lifespans of HLO values.
StatusOr<int64> ComputePeakMemory( StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
const HloComputation* computation, const HloInstructionSequence& order) const;
const std::vector<const HloInstruction*>& order) const;
// Returns the peak memory usage of the called computations for the given // Returns the peak memory usage of the called computations for the given
// instruction. Zero is returned if the instruction calls no computations. // instruction. Zero is returned if the instruction calls no computations.

View File

@ -46,8 +46,8 @@ namespace xla {
<< "No computation exists in HLO module with id " << computation_id; << "No computation exists in HLO module with id " << computation_id;
const HloComputation* computation = comp_it->second; const HloComputation* computation = comp_it->second;
absl::flat_hash_map<int64, const HloInstruction*> id_to_instruction; absl::flat_hash_map<int64, HloInstruction*> id_to_instruction;
for (const HloInstruction* instruction : computation->instructions()) { for (HloInstruction* instruction : computation->instructions()) {
id_to_instruction[instruction->unique_id()] = instruction; id_to_instruction[instruction->unique_id()] = instruction;
} }
@ -81,9 +81,8 @@ StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
return std::move(proto); return std::move(proto);
} }
void HloSchedule::set_sequence( void HloSchedule::set_sequence(const HloComputation* computation,
const HloComputation* computation, absl::Span<HloInstruction* const> sequence) {
absl::Span<const HloInstruction* const> sequence) {
set_sequence(computation, HloInstructionSequence(sequence)); set_sequence(computation, HloInstructionSequence(sequence));
} }
@ -114,8 +113,8 @@ Status HloSchedule::UpdateComputationSchedule(
const HloComputation* computation) { const HloComputation* computation) {
// Map from unique ID to HloInstruction pointer for instructions in the // Map from unique ID to HloInstruction pointer for instructions in the
// computation. // computation.
absl::flat_hash_map<int, const HloInstruction*> id_to_instruction; absl::flat_hash_map<int, HloInstruction*> id_to_instruction;
for (const HloInstruction* instruction : computation->instructions()) { for (HloInstruction* instruction : computation->instructions()) {
InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
} }
@ -128,7 +127,7 @@ Status HloSchedule::UpdateComputationSchedule(
// Map from HloInstruction X to newly added instructions (instruction is in // Map from HloInstruction X to newly added instructions (instruction is in
// computation, but not in schedule) which use X. If an instruction is not in // computation, but not in schedule) which use X. If an instruction is not in
// the map, then it has no users which are newly added instructions. // the map, then it has no users which are newly added instructions.
absl::flat_hash_map<const HloInstruction*, std::vector<const HloInstruction*>> absl::flat_hash_map<const HloInstruction*, std::vector<HloInstruction*>>
new_instruction_uses; new_instruction_uses;
// For each newly added instruction, this is the count of the instruction's // For each newly added instruction, this is the count of the instruction's
@ -138,9 +137,9 @@ Status HloSchedule::UpdateComputationSchedule(
// Create a worklist of newly added instructions which are ready to be added // Create a worklist of newly added instructions which are ready to be added
// to the schedule. Initialize worklist with those that have zero operands. // to the schedule. Initialize worklist with those that have zero operands.
std::queue<const HloInstruction*> worklist; std::queue<HloInstruction*> worklist;
for (const HloInstruction* instruction : computation->instructions()) { for (HloInstruction* instruction : computation->instructions()) {
if (ids_in_schedule.count(instruction->unique_id()) == 0) { if (ids_in_schedule.count(instruction->unique_id()) == 0) {
// This is a newly added instruction which is not in the schedule. // This is a newly added instruction which is not in the schedule.
if (instruction->operands().empty()) { if (instruction->operands().empty()) {
@ -161,17 +160,17 @@ Status HloSchedule::UpdateComputationSchedule(
// Lambda which schedules all instructions on the worklist. // Lambda which schedules all instructions on the worklist.
auto schedule_worklist = [&]() { auto schedule_worklist = [&]() {
while (!worklist.empty()) { while (!worklist.empty()) {
const HloInstruction* instruction = worklist.front(); HloInstruction* instruction = worklist.front();
worklist.pop(); worklist.pop();
new_sequence.push_back(instruction); new_sequence.push_back(instruction);
std::vector<const HloInstruction*>* new_users = std::vector<HloInstruction*>* new_users =
tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
if (new_users != nullptr) { if (new_users != nullptr) {
// This just-scheduled instruction has users which are newly added to // This just-scheduled instruction has users which are newly added to
// the module. Update the number of unscheduled operands and push the // the module. Update the number of unscheduled operands and push the
// newly added instruction to the worklist if it is ready to // newly added instruction to the worklist if it is ready to
// schedule. // schedule.
for (const HloInstruction* new_user : *new_users) { for (HloInstruction* new_user : *new_users) {
unscheduled_operand_count.at(new_user)--; unscheduled_operand_count.at(new_user)--;
CHECK_GE(unscheduled_operand_count.at(new_user), 0); CHECK_GE(unscheduled_operand_count.at(new_user), 0);
if (unscheduled_operand_count.at(new_user) == 0) { if (unscheduled_operand_count.at(new_user) == 0) {

View File

@ -35,14 +35,14 @@ class HloInstructionSequence {
public: public:
HloInstructionSequence() = default; HloInstructionSequence() = default;
explicit HloInstructionSequence( explicit HloInstructionSequence(
absl::Span<const HloInstruction* const> instructions) { absl::Span<HloInstruction* const> instructions) {
for (const HloInstruction* instruction : instructions) { for (HloInstruction* instruction : instructions) {
push_back(instruction); push_back(instruction);
} }
} }
// Adds the instruction to the end of the sequence. // Adds the instruction to the end of the sequence.
void push_back(const HloInstruction* instruction) { void push_back(HloInstruction* instruction) {
instruction_sequence_.push_back(instruction); instruction_sequence_.push_back(instruction);
id_sequence_.push_back(instruction->unique_id()); id_sequence_.push_back(instruction->unique_id());
} }
@ -56,7 +56,7 @@ class HloInstructionSequence {
int64 size() const { return instruction_sequence_.size(); } int64 size() const { return instruction_sequence_.size(); }
// Returns the sequence of HLO instructions. // Returns the sequence of HLO instructions.
const std::vector<const HloInstruction*>& instructions() const { const std::vector<HloInstruction*>& instructions() const {
return instruction_sequence_; return instruction_sequence_;
} }
@ -65,7 +65,7 @@ class HloInstructionSequence {
private: private:
// The sequence as HloInstructions. // The sequence as HloInstructions.
std::vector<const HloInstruction*> instruction_sequence_; std::vector<HloInstruction*> instruction_sequence_;
// The sequence of HLO instructions, represented by their unique IDs. The // The sequence of HLO instructions, represented by their unique IDs. The
// sequence is stored as both HloInstructions and unique IDs because the // sequence is stored as both HloInstructions and unique IDs because the
@ -98,7 +98,7 @@ class HloSchedule {
// Sets the sequence for the given computation to the given sequence. // Sets the sequence for the given computation to the given sequence.
void set_sequence(const HloComputation* computation, void set_sequence(const HloComputation* computation,
absl::Span<const HloInstruction* const> sequence); absl::Span<HloInstruction* const> sequence);
void set_sequence(const HloComputation* computation, void set_sequence(const HloComputation* computation,
HloInstructionSequence sequence); HloInstructionSequence sequence);

View File

@ -56,10 +56,10 @@ ENTRY main {
ParseHloString(module_str)); ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) { ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape()); return ShapeUtil::ByteSizeOf(buffer.shape());
})); }));
const std::vector<const HloInstruction*>& entry_schedule = const auto& entry_schedule =
schedule.sequence(module->entry_computation()).instructions(); schedule.sequence(module->entry_computation()).instructions();
EXPECT_EQ(entry_schedule.size(), 6); EXPECT_EQ(entry_schedule.size(), 6);
@ -90,7 +90,7 @@ ENTRY main {
ParseHloString(module_str)); ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) { ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape()); return ShapeUtil::ByteSizeOf(buffer.shape());
})); }));
@ -139,7 +139,7 @@ ENTRY main {
ParseHloString(module_str)); ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) { ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape()); return ShapeUtil::ByteSizeOf(buffer.shape());
})); }));
@ -183,7 +183,7 @@ ENTRY main {
ParseHloString(module_str)); ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) { ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape()); return ShapeUtil::ByteSizeOf(buffer.shape());
})); }));
@ -244,7 +244,7 @@ ENTRY %WhileLoop () -> s32[] {
ParseHloString(module_str)); ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) { ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), return ShapeUtil::ByteSizeOf(buffer.shape(),
/*pointer_size=*/sizeof(void*)); /*pointer_size=*/sizeof(void*));
})); }));
@ -313,7 +313,7 @@ ENTRY %WhileLoop () -> s32[] {
ParseHloString(module_str)); ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule, HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) { ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), return ShapeUtil::ByteSizeOf(buffer.shape(),
/*pointer_size=*/sizeof(void*)); /*pointer_size=*/sizeof(void*));
})); }));