[TF:XLA] Remove a const_cast from HLO rematerialization.

This required making HloSchedule and related classes a sequence of HloInstruction* (instead of const HloInstruction*)

Also, use HloInstructionSequence in a few more places (hlo_rematerialization.cc)

No functional change.

PiperOrigin-RevId: 221868639
This commit is contained in:
A. Unique TensorFlower 2018-11-16 16:06:20 -08:00 committed by TensorFlower Gardener
parent 8930d5aff5
commit 2121365e74
23 changed files with 142 additions and 151 deletions

View File

@ -137,8 +137,7 @@ class BufferAssignmentTest : public HloTestBase {
}
std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
HloModule* module,
absl::Span<const HloInstruction* const> instruction_sequence,
HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
int64 alignment = 1) {
HloSchedule schedule(module);
schedule.set_sequence(module->entry_computation(), instruction_sequence);
@ -1853,7 +1852,7 @@ class WhileBufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
HloSchedule schedule =
ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
module, absl::make_unique<SequentialHloOrdering>(schedule),
ByteSizeOf,
@ -2162,7 +2161,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
// nodes are traversed during BufferAssignment.
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) {
ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(),
/*pointer_size=*/sizeof(void*));
}));
@ -2391,15 +2390,16 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
RunCopyInsertion(module.get());
HloSchedule schedule =
ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();
// To trigger b/38494731, we want a specific Hlo schedule for the
// root computation, so we overwrite that entry with a manually
// crafted sequence.
schedule.set_sequence(module->entry_computation(),
{input1, weights1, one, output1, while1->operand(0),
while1, input0, weights0, zero, output0,
while0->operand(0), while0, gte0, gte1, root_add});
schedule.set_sequence(
module->entry_computation(),
{input1, weights1, one, output1, while1->mutable_operand(0), while1,
input0, weights0, zero, output0, while0->mutable_operand(0), while0,
gte0, gte1, root_add});
// If this ASSERT fails, we constructed a bogus sequence above and this test
// itself is buggy.

View File

@ -587,9 +587,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// Select an order for emitting the HLO instructions for each
// computation. Using this sequence enables tighter buffer liveness analysis
// and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN(
HloSchedule schedule,
ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler));
TF_ASSIGN_OR_RETURN(HloSchedule schedule,
ScheduleModule(module.get(), BufferSizeBytesFunction(),
DFSMemoryScheduler));
// Run buffer allocation on the HLO graph.
TF_ASSIGN_OR_RETURN(
@ -779,7 +779,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
XLA_VLOG_LINES(2, module->ToString());
TF_ASSIGN_OR_RETURN(HloSchedule schedule,
ScheduleModule(*module, BufferSizeBytesFunction()));
ScheduleModule(module, BufferSizeBytesFunction()));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.

View File

@ -111,7 +111,7 @@ IrEmitter::IrEmitter(
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
const std::vector<const HloInstruction*>* instruction_order) {
const std::vector<HloInstruction*>* instruction_order) {
string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
<< "]; ordered? " << (instruction_order != nullptr);

View File

@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
StatusOr<llvm::Function*> EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
const std::vector<const HloInstruction*>* instruction_order);
const std::vector<HloInstruction*>* instruction_order);
llvm::IRBuilder<>* b() { return &b_; }

View File

@ -37,7 +37,7 @@ class GpuHloOrdering : public PredecessorHloOrdering {
public:
GpuHloOrdering(const HloModule* module,
const StreamAssignment& stream_assignment,
const std::vector<const HloInstruction*>& thunk_launch_order);
const std::vector<HloInstruction*>& thunk_launch_order);
~GpuHloOrdering() override = default;
// Only the entry computation can possibly be sequentially ordered, and only
@ -56,7 +56,7 @@ class GpuHloOrdering : public PredecessorHloOrdering {
GpuHloOrdering::GpuHloOrdering(
const HloModule* module, const StreamAssignment& stream_assignment,
const std::vector<const HloInstruction*>& thunk_launch_order)
const std::vector<HloInstruction*>& thunk_launch_order)
: PredecessorHloOrdering(module) {
// The entry computation has a total order when there's only one stream.
if (stream_assignment.StreamCount() == 1) {
@ -150,7 +150,7 @@ GpuHloOrdering::GpuHloOrdering(
// However, if the total order is A,B,D,C,E, then C and E can run
// concurrently.
void BFSLaunchOrder(const HloComputation* computation,
std::vector<const HloInstruction*>* launch_order) {
std::vector<HloInstruction*>* launch_order) {
// This topological sort uses two data structures:
// 1. `incoming_edge_count` which keeps track of the number of incoming
// edges to each HLO;
@ -158,9 +158,9 @@ void BFSLaunchOrder(const HloComputation* computation,
//
// The sorting algorithm repeatedly pops the top from the queue and deletes
// that HLO from the graph, making more HLOs incoming-edge free.
std::deque<const HloInstruction*> queue;
std::deque<HloInstruction*> queue;
std::unordered_map<const HloInstruction*, int64> incoming_edge_count;
for (const auto& hlo : computation->instructions()) {
for (auto* hlo : computation->instructions()) {
if (hlo->operand_count() == 0) {
queue.push_back(hlo);
} else {
@ -172,10 +172,10 @@ void BFSLaunchOrder(const HloComputation* computation,
}
while (!queue.empty()) {
const HloInstruction* x = queue.front();
HloInstruction* x = queue.front();
queue.pop_front();
launch_order->push_back(x);
for (const HloInstruction* y : x->users()) {
for (HloInstruction* y : x->users()) {
--incoming_edge_count[y];
if (incoming_edge_count[y] == 0) {
queue.push_back(y);
@ -195,14 +195,14 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
std::unique_ptr<GpuHloSchedule> schedule(new GpuHloSchedule);
// Initialize thunk_launch_order_, the total order of thunk launches.
const HloComputation* entry_computation = module.entry_computation();
HloComputation* entry_computation = module.entry_computation();
if (stream_assignment.StreamCount() == 1) {
// All kernels are launched on a single stream, so there's no loss of
// concurrency by optimizing for minimal memory usage.
TF_ASSIGN_OR_RETURN(
HloInstructionSequence sequence,
ScheduleComputation(
*entry_computation, [pointer_size](const BufferValue& buffer) {
entry_computation, [pointer_size](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
}));
schedule->thunk_launch_order_ = sequence.instructions();

View File

@ -46,7 +46,7 @@ class GpuHloSchedule {
// Returns the total order of thunk launches, represented in terms of HLO
// instructions.
const std::vector<const HloInstruction*>& ThunkLaunchOrder() const {
const std::vector<HloInstruction*>& ThunkLaunchOrder() const {
return thunk_launch_order_;
}
@ -60,7 +60,7 @@ class GpuHloSchedule {
private:
GpuHloSchedule();
std::vector<const HloInstruction*> thunk_launch_order_;
std::vector<HloInstruction*> thunk_launch_order_;
std::unique_ptr<HloOrdering> hlo_ordering_;
};

View File

@ -33,7 +33,7 @@ namespace gpu {
class GpuHloScheduleTest : public HloTestBase {
protected:
using HloVec = std::vector<const HloInstruction*>;
using HloVec = std::vector<HloInstruction*>;
// Pre-canned shapes.
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});

View File

@ -45,7 +45,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands(
ThunkSchedule::ThunkSchedule(
std::unique_ptr<ThunkSequence> thunks,
std::unique_ptr<StreamAssignment> stream_assignment,
const std::vector<const HloInstruction*>& hlo_total_order)
const std::vector<HloInstruction*>& hlo_total_order)
: thunks_(std::move(thunks)),
stream_assignment_(std::move(stream_assignment)) {
std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk;
@ -53,7 +53,7 @@ ThunkSchedule::ThunkSchedule(
InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
}
for (const HloInstruction* hlo : hlo_total_order) {
for (HloInstruction* hlo : hlo_total_order) {
if (hlo_to_thunk.count(hlo)) {
thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo));
}

View File

@ -46,7 +46,7 @@ class ThunkSchedule {
public:
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,
std::unique_ptr<StreamAssignment> stream_assignment,
const std::vector<const HloInstruction*>& hlo_total_order);
const std::vector<HloInstruction*>& hlo_total_order);
// Returns the total order of executing all the thunks.
const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }

View File

@ -258,7 +258,7 @@ class HeapSimulatorTracker {
// Constructor for testing a single entry computation.
HeapSimulatorTracker(
const string& name, std::unique_ptr<HloComputation> computation,
const std::vector<const HloInstruction*>& instruction_sequence) {
const std::vector<HloInstruction*>& instruction_sequence) {
HloModuleConfig config;
module_ = absl::make_unique<HloModule>(name, config);
module_->AddEntryComputation(std::move(computation));
@ -286,7 +286,7 @@ class HeapSimulatorTracker {
// Similar to the single entry computation constructor above, but runs the
// simulation over the entire module.
void RunWholeModule(
const std::vector<const HloInstruction*>& full_module_sequence) {
const std::vector<HloInstruction*>& full_module_sequence) {
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
@ -294,7 +294,7 @@ class HeapSimulatorTracker {
HloSchedule schedule(module_.get());
absl::flat_hash_map<const HloInstruction*, int> reverse_position;
for (int i = 0; i < full_module_sequence.size(); ++i) {
const HloInstruction* instruction = full_module_sequence[i];
HloInstruction* instruction = full_module_sequence[i];
schedule.GetOrCreateSequence(instruction->parent())
.push_back(instruction);
reverse_position[instruction] = full_module_sequence.size() - i;

View File

@ -795,7 +795,7 @@ Status HloComputation::AcceptWithOperandOrder(
template <typename HloInstructionPtr>
Status HloComputation::AcceptOrdered(
DfsHloVisitorBase<HloInstructionPtr>* visitor,
const std::vector<const HloInstruction*>& order) const {
const std::vector<HloInstruction*>& order) const {
VLOG(3) << "Accepting visitor with order.";
for (HloInstruction* root : CollectUnreachableRoots()) {
TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
@ -825,9 +825,9 @@ Status HloComputation::AcceptOrdered(
// Explicit instantiations.
template Status HloComputation::AcceptOrdered(
DfsHloVisitor*, const std::vector<const HloInstruction*>&) const;
DfsHloVisitor*, const std::vector<HloInstruction*>&) const;
template Status HloComputation::AcceptOrdered(
ConstDfsHloVisitor*, const std::vector<const HloInstruction*>&) const;
ConstDfsHloVisitor*, const std::vector<HloInstruction*>&) const;
Status HloComputation::Accept(
const std::function<Status(HloInstruction*)>& visitor_func) {

View File

@ -301,7 +301,7 @@ class HloComputation {
// be a topological sort of all instructions in the computation.
template <typename HloInstructionPtr>
Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
const std::vector<const HloInstruction*>& order) const;
const std::vector<HloInstruction*>& order) const;
// Same as Accept() above, but the visitor is given as a function.
Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);

View File

@ -73,7 +73,7 @@ class ListScheduler {
// Construct and return a memory-minimizing sequence of HLO instructions
// containing the given HLO computation.
static StatusOr<HloInstructionSequence> Run(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
@ -98,7 +98,7 @@ class ListScheduler {
// comparison operators.
using Priority = std::pair<int64, int64>;
ListScheduler(const HloComputation& computation,
ListScheduler(HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
@ -111,7 +111,7 @@ class ListScheduler {
// instruction. An HLO instruction "uses" a LogicalBuffer if the
// LogicalBuffer is in an operand of the instruction as indicated by
// points-to analysis.
for (auto* instruction : computation.instructions()) {
for (auto* instruction : computation->instructions()) {
absl::flat_hash_set<const LogicalBuffer*> instr_uses;
for (auto* operand : instruction->operands()) {
points_to_analysis.GetPointsToSet(operand).ForEachElement(
@ -126,13 +126,13 @@ class ListScheduler {
// Create map containing the number of unscheduled uses (hlo instructions)
// of each logical buffer.
for (auto* instruction : computation.instructions()) {
for (auto* instruction : computation->instructions()) {
for (auto* buffer :
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
unscheduled_use_count_[buffer] = 0;
}
}
for (auto* instruction : computation.instructions()) {
for (auto* instruction : computation->instructions()) {
for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
++unscheduled_use_count_[buffer];
}
@ -141,7 +141,7 @@ class ListScheduler {
// Buffers live out of the computation have an implicit use at the end of
// the computation.
for (const LogicalBuffer* live_out_buffer :
points_to_analysis.GetPointsToSet(computation.root_instruction())
points_to_analysis.GetPointsToSet(computation->root_instruction())
.CreateFlattenedSet()) {
++unscheduled_use_count_[live_out_buffer];
}
@ -157,7 +157,7 @@ class ListScheduler {
// HloInstruction, plus some cached metadata, saved for the purposes of making
// BytesFreedIfScheduled fast.
struct ReadyListEntry {
const HloInstruction* instruction;
HloInstruction* instruction;
// The total size of all buffers defined by this instruction.
int64 bytes_defined;
@ -171,7 +171,7 @@ class ListScheduler {
};
// Creates a ReadyListEntry for the given instruction.
ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) {
ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
ReadyListEntry entry;
entry.instruction = instruction;
@ -250,13 +250,13 @@ class ListScheduler {
// Populate the ready list with instructions which have no operands or
// control predecessors.
absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
for (auto* instruction : computation_.instructions()) {
for (auto* instruction : computation_->instructions()) {
// TODO(b/34466113): Replace this and above with successors() or
// predecessors() when these methods are added to HloInstruction.
for (const HloInstruction* user : instruction->users()) {
for (HloInstruction* user : instruction->users()) {
unscheduled_pred_count[user]++;
}
for (const HloInstruction* succ : instruction->control_successors()) {
for (HloInstruction* succ : instruction->control_successors()) {
unscheduled_pred_count[succ]++;
}
}
@ -275,7 +275,7 @@ class ListScheduler {
ready_instructions[inst] = it;
};
for (auto* instruction : computation_.instructions()) {
for (auto* instruction : computation_->instructions()) {
if (instruction->operands().empty() &&
instruction->control_predecessors().empty()) {
add_to_ready_queue(instruction);
@ -287,7 +287,7 @@ class ListScheduler {
// schedule.
auto best_it = ready_queue.end();
--best_it;
const HloInstruction* best = best_it->second.instruction;
HloInstruction* best = best_it->second.instruction;
VLOG(2) << "Schedule instruction: " << best->ToShortString()
<< " Bytes freed: " << best_it->first.first;
ready_queue.erase(best_it);
@ -348,13 +348,13 @@ class ListScheduler {
}
}
}
CHECK_EQ(schedule.size(), computation_.instruction_count());
CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count());
CHECK_EQ(schedule.size(), computation_->instruction_count());
CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
return schedule;
}
const HloComputation& computation_;
HloComputation* computation_;
const TuplePointsToAnalysis& points_to_analysis_;
const LogicalBuffer::SizeFunction& size_function_;
// Computations are analyzed in post-order. When scheduling an instruction
@ -386,13 +386,13 @@ int64 SumLogicalBufferSizes(
}
StatusOr<HloInstructionSequence> ScheduleComputationHelper(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm,
const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
VLOG(2) << "Computation: " << computation.name();
VLOG(2) << "Computation: " << computation->name();
if (algorithm) {
return algorithm(computation, points_to_analysis, size_function,
memory_by_computation);
@ -404,17 +404,17 @@ StatusOr<HloInstructionSequence> ScheduleComputationHelper(
} // namespace
StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
// These variables are a hack to prevent overflows.
int64 cumulative_total_size = 0;
int64 total_hlos = computation.parent()->instruction_count();
int64 total_hlos = computation->parent()->instruction_count();
absl::flat_hash_map<const HloInstruction*, int64> extra_users;
absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
if (ListScheduler::IgnoreInstruction(*hlo)) {
extra_users[hlo] = 0;
total_sizes[hlo] = 0;
@ -448,8 +448,8 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
}
CHECK_EQ(extra_users.size(), computation.instruction_count());
CHECK_EQ(total_sizes.size(), computation.instruction_count());
CHECK_EQ(extra_users.size(), computation->instruction_count());
CHECK_EQ(total_sizes.size(), computation->instruction_count());
// Construct a total order based on DFS post-order, visiting operands in
// decreasing cumulative extra user order, and next by cumulative size, with a
@ -459,7 +459,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
sequence.push_back(hlo);
return Status::OK();
});
TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
&visitor, [&extra_users, &total_sizes](const HloInstruction* a,
const HloInstruction* b) {
if (extra_users[a] != extra_users[b]) {
@ -470,12 +470,12 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
}
return a->name() < b->name();
}));
CHECK_EQ(sequence.size(), computation.instruction_count());
CHECK_EQ(sequence.size(), computation->instruction_count());
return sequence;
} // namespace xla
StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
@ -485,16 +485,16 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
}
StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
return HloInstructionSequence(computation.MakeInstructionPostOrder());
return HloInstructionSequence(computation->MakeInstructionPostOrder());
}
StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
@ -513,7 +513,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 list_memory,
HeapSimulator::MinimumMemoryForComputation(
computation, list_sequence, points_to_analysis,
*computation, list_sequence, points_to_analysis,
size_function, &memory_by_computation));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
@ -522,7 +522,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
size_function, memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
HeapSimulator::MinimumMemoryForComputation(
computation, dfs_sequence, points_to_analysis,
*computation, dfs_sequence, points_to_analysis,
size_function, &memory_by_computation));
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
@ -532,7 +532,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
HeapSimulator::MinimumMemoryForComputation(
computation, post_order_sequence, points_to_analysis,
*computation, post_order_sequence, points_to_analysis,
size_function, &memory_by_computation));
VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory);
@ -555,17 +555,17 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
}
StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
HloModule* module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm) {
HloSchedule schedule(&module);
HloSchedule schedule(module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
TuplePointsToAnalysis::Run(module));
absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
for (const auto* computation : module.MakeComputationPostOrder()) {
for (auto* computation : module->MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
ScheduleComputationHelper(
*computation, *points_to_analysis, size_function,
computation, *points_to_analysis, size_function,
algorithm, memory_by_computation));
memory_by_computation[computation] =
HeapSimulator::MinimumMemoryForComputation(
@ -583,11 +583,11 @@ StatusOr<HloSchedule> ScheduleModule(
}
StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
HloComputation* computation,
const LogicalBuffer::SizeFunction& size_function) {
CHECK(!computation.IsFusionComputation());
CHECK(!computation->IsFusionComputation());
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
TuplePointsToAnalysis::Run(computation->parent()));
absl::flat_hash_map<const HloComputation*, int64> empty_map;
return ScheduleComputationHelper(computation, *points_to_analysis,
size_function, nullptr, empty_map);
@ -600,7 +600,7 @@ HloMemoryScheduler::HloMemoryScheduler(
StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(HloSchedule schedule,
ScheduleModule(*module, size_function_, algorithm_));
ScheduleModule(module, size_function_, algorithm_));
TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
return true;
}

View File

@ -36,14 +36,14 @@ namespace xla {
// that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size.
typedef std::function<StatusOr<HloInstructionSequence>(
const HloComputation&, const TuplePointsToAnalysis&,
HloComputation*, const TuplePointsToAnalysis&,
const LogicalBuffer::SizeFunction&,
const absl::flat_hash_map<const HloComputation*, int64>&)>
MemorySchedulerAlgorithm;
// List scheduler
StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
@ -51,7 +51,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
// DFS-order scheduler
StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
@ -59,7 +59,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
// Naive Post Order scheduler
StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
@ -69,7 +69,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
// and the DFS scheduler, and chooses whichever returns a lower min-memory,
// not accounting for fragmentation.
StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const absl::flat_hash_map<const HloComputation*, int64>&
@ -79,13 +79,13 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
// the computation. size_function is the function returning the number of bytes
// required for a LogicalBuffer.
StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
HloModule* module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm = {});
// Computes the schedule for a single computation.
// Currently only used by the GPU backend.
StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
HloComputation* computation,
const LogicalBuffer::SizeFunction& size_function);
// A pass which schedules the HLO instructions in a module. The HloModule's

View File

@ -78,7 +78,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
TF_ASSERT_OK(module->schedule().Verify());
// Verify that all instructions are in the sequence.
const std::vector<const HloInstruction*>& sequence =
const std::vector<HloInstruction*>& sequence =
module->schedule().sequence(module->entry_computation()).instructions();
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
@ -124,9 +124,9 @@ ENTRY root {
};
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, size_fn, ListMemoryScheduler));
ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
const std::vector<const HloInstruction*>& sequence =
const std::vector<HloInstruction*>& sequence =
schedule.sequence(module->entry_computation()).instructions();
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
@ -175,12 +175,13 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
ScheduleModule(*module,
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(
buffer.shape(), TUPLE_SIZE);
},
ListMemoryScheduler));
ScheduleModule(
module.get(),
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(),
TUPLE_SIZE);
},
ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
@ -225,12 +226,12 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
{tuple, mul, add}, HloInstruction::FusionKind::kLoop);
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
ScheduleModule(*module,
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(
buffer.shape(), 2);
},
ListMemoryScheduler));
ScheduleModule(
module.get(),
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
},
ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
@ -284,7 +285,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
};
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, size_fn, ListMemoryScheduler));
ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
auto entry_computation = module->entry_computation();
EXPECT_EQ(module->entry_computation()->instruction_count(),

View File

@ -104,11 +104,7 @@ class HloModule {
HloCloneContext* context = nullptr);
// Return a pointer to the entry computation of the module.
const HloComputation* entry_computation() const {
CHECK_NE(nullptr, entry_computation_);
return entry_computation_;
}
HloComputation* entry_computation() {
HloComputation* entry_computation() const {
CHECK_NE(nullptr, entry_computation_);
return entry_computation_;
}

View File

@ -356,8 +356,7 @@ void SequentialHloOrdering::Initialize() {
// Create a map from instruction to its order position.
TF_DCHECK_OK(schedule_.Verify());
for (const auto& computation_sequence : schedule_.sequences()) {
const std::vector<const HloInstruction*>& order =
computation_sequence.second.instructions();
const auto& order = computation_sequence.second.instructions();
for (int i = 0; i < order.size(); ++i) {
InsertOrDie(&order_position_, order[i], i);
}

View File

@ -47,11 +47,11 @@ const double kF16max = 65504;
// Creates and returns a schedule created using the order of the instructions in
// the HloComputation::instructions() vectors in the module.
HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
HloSchedule schedule(module);
for (const HloComputation* computation : module->computations()) {
for (HloComputation* computation : module->computations()) {
if (!computation->IsFusionComputation()) {
for (const HloInstruction* instruction : computation->instructions()) {
for (HloInstruction* instruction : computation->instructions()) {
schedule.GetOrCreateSequence(computation).push_back(instruction);
}
}

View File

@ -130,10 +130,10 @@ using ItemList = absl::InlinedVector<Item*, 3>;
// before arbitrary elements.
class InstructionList {
public:
explicit InstructionList(const std::vector<const HloInstruction*>& order) {
explicit InstructionList(const HloInstructionSequence& order) {
int64 position = 0;
Item* last = nullptr;
for (const HloInstruction* inst : order) {
for (HloInstruction* inst : order.instructions()) {
// Add a new item to the linked list.
Item* item = new Item;
item->next = nullptr;
@ -151,7 +151,7 @@ class InstructionList {
// to be monotonically increasing through the list, and so is still useful
// for quickly(-ish) determining the order of arbitrary instructions in
// the list.
item->instruction = const_cast<HloInstruction*>(inst);
item->instruction = inst;
item->position = position;
position++;
@ -927,7 +927,7 @@ Item* PickRematerializationCandidate(
StatusOr<int64> HloRematerialization::ComputePeakMemory(
const HloComputation* computation,
const std::vector<const HloInstruction*>& order) const {
const HloInstructionSequence& order) const {
InstructionList instruction_list(order);
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
instruction_list);
@ -971,8 +971,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< HumanReadableNumBytes(computation_peak_memory_.at(computation));
CHECK(!ContainsKey(rematerialized_computations_, computation));
InstructionList instruction_list(
schedule->sequence(computation).instructions());
InstructionList instruction_list(schedule->sequence(computation));
MemoryUsageTracker memory_tracker(computation, size_function_,
*points_to_analysis_, instruction_list);
bool changed = false;
@ -1184,7 +1183,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
sequence.clear();
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
HloInstruction* instruction = item->instruction;
sequence.push_back(instruction);
}
rematerialized_computations_.insert(computation);
@ -1235,10 +1234,8 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()],
ComputePeakMemory(node.computation(),
module->schedule()
.sequence(node.computation())
.instructions()));
ComputePeakMemory(node.computation(), module->schedule().sequence(
node.computation())));
}
return Status::OK();
},

View File

@ -87,9 +87,8 @@ class HloRematerialization : public HloModulePass {
// peak memory is the maximum total size of all live HLO instruction values at
// any program point. 'order' is the order in which the HLO instructions will
// be emitted which is used to determine lifespans of HLO values.
StatusOr<int64> ComputePeakMemory(
const HloComputation* computation,
const std::vector<const HloInstruction*>& order) const;
StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
const HloInstructionSequence& order) const;
// Returns the peak memory usage of the called computations for the given
// instruction. Zero is returned if the instruction calls no computations.

View File

@ -46,8 +46,8 @@ namespace xla {
<< "No computation exists in HLO module with id " << computation_id;
const HloComputation* computation = comp_it->second;
absl::flat_hash_map<int64, const HloInstruction*> id_to_instruction;
for (const HloInstruction* instruction : computation->instructions()) {
absl::flat_hash_map<int64, HloInstruction*> id_to_instruction;
for (HloInstruction* instruction : computation->instructions()) {
id_to_instruction[instruction->unique_id()] = instruction;
}
@ -81,9 +81,8 @@ StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
return std::move(proto);
}
void HloSchedule::set_sequence(
const HloComputation* computation,
absl::Span<const HloInstruction* const> sequence) {
void HloSchedule::set_sequence(const HloComputation* computation,
absl::Span<HloInstruction* const> sequence) {
set_sequence(computation, HloInstructionSequence(sequence));
}
@ -114,8 +113,8 @@ Status HloSchedule::UpdateComputationSchedule(
const HloComputation* computation) {
// Map from unique ID to HloInstruction pointer for instructions in the
// computation.
absl::flat_hash_map<int, const HloInstruction*> id_to_instruction;
for (const HloInstruction* instruction : computation->instructions()) {
absl::flat_hash_map<int, HloInstruction*> id_to_instruction;
for (HloInstruction* instruction : computation->instructions()) {
InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
}
@ -128,7 +127,7 @@ Status HloSchedule::UpdateComputationSchedule(
// Map from HloInstruction X to newly added instructions (instruction is in
// computation, but not in schedule) which use X. If an instruction is not in
// the map, then it has no users which are newly added instructions.
absl::flat_hash_map<const HloInstruction*, std::vector<const HloInstruction*>>
absl::flat_hash_map<const HloInstruction*, std::vector<HloInstruction*>>
new_instruction_uses;
// For each newly added instruction, this is the count of the instruction's
@ -138,9 +137,9 @@ Status HloSchedule::UpdateComputationSchedule(
// Create a worklist of newly added instructions which are ready to be added
// to the schedule. Initialize worklist with those that have zero operands.
std::queue<const HloInstruction*> worklist;
std::queue<HloInstruction*> worklist;
for (const HloInstruction* instruction : computation->instructions()) {
for (HloInstruction* instruction : computation->instructions()) {
if (ids_in_schedule.count(instruction->unique_id()) == 0) {
// This is a newly added instruction which is not in the schedule.
if (instruction->operands().empty()) {
@ -161,17 +160,17 @@ Status HloSchedule::UpdateComputationSchedule(
// Lambda which schedules all instructions on the worklist.
auto schedule_worklist = [&]() {
while (!worklist.empty()) {
const HloInstruction* instruction = worklist.front();
HloInstruction* instruction = worklist.front();
worklist.pop();
new_sequence.push_back(instruction);
std::vector<const HloInstruction*>* new_users =
std::vector<HloInstruction*>* new_users =
tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
if (new_users != nullptr) {
// This just-scheduled instruction has users which are newly added to
// the module. Update the number of unscheduled operands and push the
// newly added instruction to the worklist if it is ready to
// schedule.
for (const HloInstruction* new_user : *new_users) {
for (HloInstruction* new_user : *new_users) {
unscheduled_operand_count.at(new_user)--;
CHECK_GE(unscheduled_operand_count.at(new_user), 0);
if (unscheduled_operand_count.at(new_user) == 0) {

View File

@ -35,14 +35,14 @@ class HloInstructionSequence {
public:
HloInstructionSequence() = default;
explicit HloInstructionSequence(
absl::Span<const HloInstruction* const> instructions) {
for (const HloInstruction* instruction : instructions) {
absl::Span<HloInstruction* const> instructions) {
for (HloInstruction* instruction : instructions) {
push_back(instruction);
}
}
// Adds the instruction to the end of the sequence.
void push_back(const HloInstruction* instruction) {
void push_back(HloInstruction* instruction) {
instruction_sequence_.push_back(instruction);
id_sequence_.push_back(instruction->unique_id());
}
@ -56,7 +56,7 @@ class HloInstructionSequence {
int64 size() const { return instruction_sequence_.size(); }
// Returns the sequence of HLO instructions.
const std::vector<const HloInstruction*>& instructions() const {
const std::vector<HloInstruction*>& instructions() const {
return instruction_sequence_;
}
@ -65,7 +65,7 @@ class HloInstructionSequence {
private:
// The sequence as HloInstructions.
std::vector<const HloInstruction*> instruction_sequence_;
std::vector<HloInstruction*> instruction_sequence_;
// The sequence of HLO instructions, represented by their unique IDs. The
// sequence is stored as both HloInstructions and unique IDs because the
@ -98,7 +98,7 @@ class HloSchedule {
// Sets the sequence for the given computation to the given sequence.
void set_sequence(const HloComputation* computation,
absl::Span<const HloInstruction* const> sequence);
absl::Span<HloInstruction* const> sequence);
void set_sequence(const HloComputation* computation,
HloInstructionSequence sequence);

View File

@ -56,10 +56,10 @@ ENTRY main {
ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) {
ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
}));
const std::vector<const HloInstruction*>& entry_schedule =
const auto& entry_schedule =
schedule.sequence(module->entry_computation()).instructions();
EXPECT_EQ(entry_schedule.size(), 6);
@ -90,7 +90,7 @@ ENTRY main {
ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) {
ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
}));
@ -139,7 +139,7 @@ ENTRY main {
ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) {
ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
}));
@ -183,7 +183,7 @@ ENTRY main {
ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) {
ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
}));
@ -244,7 +244,7 @@ ENTRY %WhileLoop () -> s32[] {
ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) {
ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(),
/*pointer_size=*/sizeof(void*));
}));
@ -313,7 +313,7 @@ ENTRY %WhileLoop () -> s32[] {
ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(*module, [](const BufferValue& buffer) {
ScheduleModule(module.get(), [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(),
/*pointer_size=*/sizeof(void*));
}));