[TF:XLA] Remove a const_cast from HLO rematerialization.
This required making HloSchedule and related classes a sequence of HloInstruction* (instead of const HloInstruction*) Also, use HloInstructionSequence in a few more places (hlo_rematerialization.cc) No functional change. PiperOrigin-RevId: 221868639
This commit is contained in:
parent
8930d5aff5
commit
2121365e74
@ -137,8 +137,7 @@ class BufferAssignmentTest : public HloTestBase {
|
||||
}
|
||||
|
||||
std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
|
||||
HloModule* module,
|
||||
absl::Span<const HloInstruction* const> instruction_sequence,
|
||||
HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
|
||||
int64 alignment = 1) {
|
||||
HloSchedule schedule(module);
|
||||
schedule.set_sequence(module->entry_computation(), instruction_sequence);
|
||||
@ -1853,7 +1852,7 @@ class WhileBufferAssignmentTest : public HloTestBase {
|
||||
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
|
||||
int64 alignment = 1) {
|
||||
HloSchedule schedule =
|
||||
ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
|
||||
ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
|
||||
return BufferAssigner::Run(
|
||||
module, absl::make_unique<SequentialHloOrdering>(schedule),
|
||||
ByteSizeOf,
|
||||
@ -2162,7 +2161,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
|
||||
// nodes are traversed during BufferAssignment.
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, [](const BufferValue& buffer) {
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape(),
|
||||
/*pointer_size=*/sizeof(void*));
|
||||
}));
|
||||
@ -2391,15 +2390,16 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
|
||||
RunCopyInsertion(module.get());
|
||||
|
||||
HloSchedule schedule =
|
||||
ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
|
||||
ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();
|
||||
|
||||
// To trigger b/38494731, we want a specific Hlo schedule for the
|
||||
// root computation, so we overwrite that entry with a manually
|
||||
// crafted sequence.
|
||||
schedule.set_sequence(module->entry_computation(),
|
||||
{input1, weights1, one, output1, while1->operand(0),
|
||||
while1, input0, weights0, zero, output0,
|
||||
while0->operand(0), while0, gte0, gte1, root_add});
|
||||
schedule.set_sequence(
|
||||
module->entry_computation(),
|
||||
{input1, weights1, one, output1, while1->mutable_operand(0), while1,
|
||||
input0, weights0, zero, output0, while0->mutable_operand(0), while0,
|
||||
gte0, gte1, root_add});
|
||||
|
||||
// If this ASSERT fails, we constructed a bogus sequence above and this test
|
||||
// itself is buggy.
|
||||
|
@ -587,9 +587,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
||||
// Select an order for emitting the HLO instructions for each
|
||||
// computation. Using this sequence enables tighter buffer liveness analysis
|
||||
// and reduced memory usage (as compared to using DependencyHloOrdering).
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler));
|
||||
TF_ASSIGN_OR_RETURN(HloSchedule schedule,
|
||||
ScheduleModule(module.get(), BufferSizeBytesFunction(),
|
||||
DFSMemoryScheduler));
|
||||
|
||||
// Run buffer allocation on the HLO graph.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -779,7 +779,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(HloSchedule schedule,
|
||||
ScheduleModule(*module, BufferSizeBytesFunction()));
|
||||
ScheduleModule(module, BufferSizeBytesFunction()));
|
||||
|
||||
// Run buffer analysis on the HLO graph. This analysis figures out which
|
||||
// temporary buffers are required to run the computation.
|
||||
|
@ -111,7 +111,7 @@ IrEmitter::IrEmitter(
|
||||
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
|
||||
HloComputation* computation, const string& function_name_prefix,
|
||||
bool is_top_level_computation,
|
||||
const std::vector<const HloInstruction*>* instruction_order) {
|
||||
const std::vector<HloInstruction*>* instruction_order) {
|
||||
string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
|
||||
VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
|
||||
<< "]; ordered? " << (instruction_order != nullptr);
|
||||
|
@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
StatusOr<llvm::Function*> EmitComputation(
|
||||
HloComputation* computation, const string& function_name_prefix,
|
||||
bool is_top_level_computation,
|
||||
const std::vector<const HloInstruction*>* instruction_order);
|
||||
const std::vector<HloInstruction*>* instruction_order);
|
||||
|
||||
llvm::IRBuilder<>* b() { return &b_; }
|
||||
|
||||
|
@ -37,7 +37,7 @@ class GpuHloOrdering : public PredecessorHloOrdering {
|
||||
public:
|
||||
GpuHloOrdering(const HloModule* module,
|
||||
const StreamAssignment& stream_assignment,
|
||||
const std::vector<const HloInstruction*>& thunk_launch_order);
|
||||
const std::vector<HloInstruction*>& thunk_launch_order);
|
||||
~GpuHloOrdering() override = default;
|
||||
|
||||
// Only the entry computation can possibly be sequentially ordered, and only
|
||||
@ -56,7 +56,7 @@ class GpuHloOrdering : public PredecessorHloOrdering {
|
||||
|
||||
GpuHloOrdering::GpuHloOrdering(
|
||||
const HloModule* module, const StreamAssignment& stream_assignment,
|
||||
const std::vector<const HloInstruction*>& thunk_launch_order)
|
||||
const std::vector<HloInstruction*>& thunk_launch_order)
|
||||
: PredecessorHloOrdering(module) {
|
||||
// The entry computation has a total order when there's only one stream.
|
||||
if (stream_assignment.StreamCount() == 1) {
|
||||
@ -150,7 +150,7 @@ GpuHloOrdering::GpuHloOrdering(
|
||||
// However, if the total order is A,B,D,C,E, then C and E can run
|
||||
// concurrently.
|
||||
void BFSLaunchOrder(const HloComputation* computation,
|
||||
std::vector<const HloInstruction*>* launch_order) {
|
||||
std::vector<HloInstruction*>* launch_order) {
|
||||
// This topological sort uses two data structures:
|
||||
// 1. `incoming_edge_count` which keeps track of the number of incoming
|
||||
// edges to each HLO;
|
||||
@ -158,9 +158,9 @@ void BFSLaunchOrder(const HloComputation* computation,
|
||||
//
|
||||
// The sorting algorithm repeatedly pops the top from the queue and deletes
|
||||
// that HLO from the graph, making more HLOs incoming-edge free.
|
||||
std::deque<const HloInstruction*> queue;
|
||||
std::deque<HloInstruction*> queue;
|
||||
std::unordered_map<const HloInstruction*, int64> incoming_edge_count;
|
||||
for (const auto& hlo : computation->instructions()) {
|
||||
for (auto* hlo : computation->instructions()) {
|
||||
if (hlo->operand_count() == 0) {
|
||||
queue.push_back(hlo);
|
||||
} else {
|
||||
@ -172,10 +172,10 @@ void BFSLaunchOrder(const HloComputation* computation,
|
||||
}
|
||||
|
||||
while (!queue.empty()) {
|
||||
const HloInstruction* x = queue.front();
|
||||
HloInstruction* x = queue.front();
|
||||
queue.pop_front();
|
||||
launch_order->push_back(x);
|
||||
for (const HloInstruction* y : x->users()) {
|
||||
for (HloInstruction* y : x->users()) {
|
||||
--incoming_edge_count[y];
|
||||
if (incoming_edge_count[y] == 0) {
|
||||
queue.push_back(y);
|
||||
@ -195,14 +195,14 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
|
||||
std::unique_ptr<GpuHloSchedule> schedule(new GpuHloSchedule);
|
||||
|
||||
// Initialize thunk_launch_order_, the total order of thunk launches.
|
||||
const HloComputation* entry_computation = module.entry_computation();
|
||||
HloComputation* entry_computation = module.entry_computation();
|
||||
if (stream_assignment.StreamCount() == 1) {
|
||||
// All kernels are launched on a single stream, so there's no loss of
|
||||
// concurrency by optimizing for minimal memory usage.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstructionSequence sequence,
|
||||
ScheduleComputation(
|
||||
*entry_computation, [pointer_size](const BufferValue& buffer) {
|
||||
entry_computation, [pointer_size](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
|
||||
}));
|
||||
schedule->thunk_launch_order_ = sequence.instructions();
|
||||
|
@ -46,7 +46,7 @@ class GpuHloSchedule {
|
||||
|
||||
// Returns the total order of thunk launches, represented in terms of HLO
|
||||
// instructions.
|
||||
const std::vector<const HloInstruction*>& ThunkLaunchOrder() const {
|
||||
const std::vector<HloInstruction*>& ThunkLaunchOrder() const {
|
||||
return thunk_launch_order_;
|
||||
}
|
||||
|
||||
@ -60,7 +60,7 @@ class GpuHloSchedule {
|
||||
private:
|
||||
GpuHloSchedule();
|
||||
|
||||
std::vector<const HloInstruction*> thunk_launch_order_;
|
||||
std::vector<HloInstruction*> thunk_launch_order_;
|
||||
std::unique_ptr<HloOrdering> hlo_ordering_;
|
||||
};
|
||||
|
||||
|
@ -33,7 +33,7 @@ namespace gpu {
|
||||
|
||||
class GpuHloScheduleTest : public HloTestBase {
|
||||
protected:
|
||||
using HloVec = std::vector<const HloInstruction*>;
|
||||
using HloVec = std::vector<HloInstruction*>;
|
||||
|
||||
// Pre-canned shapes.
|
||||
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
|
@ -45,7 +45,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands(
|
||||
ThunkSchedule::ThunkSchedule(
|
||||
std::unique_ptr<ThunkSequence> thunks,
|
||||
std::unique_ptr<StreamAssignment> stream_assignment,
|
||||
const std::vector<const HloInstruction*>& hlo_total_order)
|
||||
const std::vector<HloInstruction*>& hlo_total_order)
|
||||
: thunks_(std::move(thunks)),
|
||||
stream_assignment_(std::move(stream_assignment)) {
|
||||
std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk;
|
||||
@ -53,7 +53,7 @@ ThunkSchedule::ThunkSchedule(
|
||||
InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
|
||||
}
|
||||
|
||||
for (const HloInstruction* hlo : hlo_total_order) {
|
||||
for (HloInstruction* hlo : hlo_total_order) {
|
||||
if (hlo_to_thunk.count(hlo)) {
|
||||
thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo));
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ class ThunkSchedule {
|
||||
public:
|
||||
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,
|
||||
std::unique_ptr<StreamAssignment> stream_assignment,
|
||||
const std::vector<const HloInstruction*>& hlo_total_order);
|
||||
const std::vector<HloInstruction*>& hlo_total_order);
|
||||
|
||||
// Returns the total order of executing all the thunks.
|
||||
const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }
|
||||
|
@ -258,7 +258,7 @@ class HeapSimulatorTracker {
|
||||
// Constructor for testing a single entry computation.
|
||||
HeapSimulatorTracker(
|
||||
const string& name, std::unique_ptr<HloComputation> computation,
|
||||
const std::vector<const HloInstruction*>& instruction_sequence) {
|
||||
const std::vector<HloInstruction*>& instruction_sequence) {
|
||||
HloModuleConfig config;
|
||||
module_ = absl::make_unique<HloModule>(name, config);
|
||||
module_->AddEntryComputation(std::move(computation));
|
||||
@ -286,7 +286,7 @@ class HeapSimulatorTracker {
|
||||
// Similar to the single entry computation constructor above, but runs the
|
||||
// simulation over the entire module.
|
||||
void RunWholeModule(
|
||||
const std::vector<const HloInstruction*>& full_module_sequence) {
|
||||
const std::vector<HloInstruction*>& full_module_sequence) {
|
||||
points_to_analysis_ =
|
||||
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
|
||||
|
||||
@ -294,7 +294,7 @@ class HeapSimulatorTracker {
|
||||
HloSchedule schedule(module_.get());
|
||||
absl::flat_hash_map<const HloInstruction*, int> reverse_position;
|
||||
for (int i = 0; i < full_module_sequence.size(); ++i) {
|
||||
const HloInstruction* instruction = full_module_sequence[i];
|
||||
HloInstruction* instruction = full_module_sequence[i];
|
||||
schedule.GetOrCreateSequence(instruction->parent())
|
||||
.push_back(instruction);
|
||||
reverse_position[instruction] = full_module_sequence.size() - i;
|
||||
|
@ -795,7 +795,7 @@ Status HloComputation::AcceptWithOperandOrder(
|
||||
template <typename HloInstructionPtr>
|
||||
Status HloComputation::AcceptOrdered(
|
||||
DfsHloVisitorBase<HloInstructionPtr>* visitor,
|
||||
const std::vector<const HloInstruction*>& order) const {
|
||||
const std::vector<HloInstruction*>& order) const {
|
||||
VLOG(3) << "Accepting visitor with order.";
|
||||
for (HloInstruction* root : CollectUnreachableRoots()) {
|
||||
TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
|
||||
@ -825,9 +825,9 @@ Status HloComputation::AcceptOrdered(
|
||||
|
||||
// Explicit instantiations.
|
||||
template Status HloComputation::AcceptOrdered(
|
||||
DfsHloVisitor*, const std::vector<const HloInstruction*>&) const;
|
||||
DfsHloVisitor*, const std::vector<HloInstruction*>&) const;
|
||||
template Status HloComputation::AcceptOrdered(
|
||||
ConstDfsHloVisitor*, const std::vector<const HloInstruction*>&) const;
|
||||
ConstDfsHloVisitor*, const std::vector<HloInstruction*>&) const;
|
||||
|
||||
Status HloComputation::Accept(
|
||||
const std::function<Status(HloInstruction*)>& visitor_func) {
|
||||
|
@ -301,7 +301,7 @@ class HloComputation {
|
||||
// be a topological sort of all instructions in the computation.
|
||||
template <typename HloInstructionPtr>
|
||||
Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
|
||||
const std::vector<const HloInstruction*>& order) const;
|
||||
const std::vector<HloInstruction*>& order) const;
|
||||
|
||||
// Same as Accept() above, but the visitor is given as a function.
|
||||
Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
|
||||
|
@ -73,7 +73,7 @@ class ListScheduler {
|
||||
// Construct and return a memory-minimizing sequence of HLO instructions
|
||||
// containing the given HLO computation.
|
||||
static StatusOr<HloInstructionSequence> Run(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
@ -98,7 +98,7 @@ class ListScheduler {
|
||||
// comparison operators.
|
||||
using Priority = std::pair<int64, int64>;
|
||||
|
||||
ListScheduler(const HloComputation& computation,
|
||||
ListScheduler(HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
@ -111,7 +111,7 @@ class ListScheduler {
|
||||
// instruction. An HLO instruction "uses" a LogicalBuffer if the
|
||||
// LogicalBuffer is in an operand of the instruction as indicated by
|
||||
// points-to analysis.
|
||||
for (auto* instruction : computation.instructions()) {
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
absl::flat_hash_set<const LogicalBuffer*> instr_uses;
|
||||
for (auto* operand : instruction->operands()) {
|
||||
points_to_analysis.GetPointsToSet(operand).ForEachElement(
|
||||
@ -126,13 +126,13 @@ class ListScheduler {
|
||||
|
||||
// Create map containing the number of unscheduled uses (hlo instructions)
|
||||
// of each logical buffer.
|
||||
for (auto* instruction : computation.instructions()) {
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
for (auto* buffer :
|
||||
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
|
||||
unscheduled_use_count_[buffer] = 0;
|
||||
}
|
||||
}
|
||||
for (auto* instruction : computation.instructions()) {
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
|
||||
++unscheduled_use_count_[buffer];
|
||||
}
|
||||
@ -141,7 +141,7 @@ class ListScheduler {
|
||||
// Buffers live out of the computation have an implicit use at the end of
|
||||
// the computation.
|
||||
for (const LogicalBuffer* live_out_buffer :
|
||||
points_to_analysis.GetPointsToSet(computation.root_instruction())
|
||||
points_to_analysis.GetPointsToSet(computation->root_instruction())
|
||||
.CreateFlattenedSet()) {
|
||||
++unscheduled_use_count_[live_out_buffer];
|
||||
}
|
||||
@ -157,7 +157,7 @@ class ListScheduler {
|
||||
// HloInstruction, plus some cached metadata, saved for the purposes of making
|
||||
// BytesFreedIfScheduled fast.
|
||||
struct ReadyListEntry {
|
||||
const HloInstruction* instruction;
|
||||
HloInstruction* instruction;
|
||||
|
||||
// The total size of all buffers defined by this instruction.
|
||||
int64 bytes_defined;
|
||||
@ -171,7 +171,7 @@ class ListScheduler {
|
||||
};
|
||||
|
||||
// Creates a ReadyListEntry for the given instruction.
|
||||
ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) {
|
||||
ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
|
||||
ReadyListEntry entry;
|
||||
entry.instruction = instruction;
|
||||
|
||||
@ -250,13 +250,13 @@ class ListScheduler {
|
||||
// Populate the ready list with instructions which have no operands or
|
||||
// control predecessors.
|
||||
absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
|
||||
for (auto* instruction : computation_.instructions()) {
|
||||
for (auto* instruction : computation_->instructions()) {
|
||||
// TODO(b/34466113): Replace this and above with successors() or
|
||||
// predecessors() when these methods are added to HloInstruction.
|
||||
for (const HloInstruction* user : instruction->users()) {
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
unscheduled_pred_count[user]++;
|
||||
}
|
||||
for (const HloInstruction* succ : instruction->control_successors()) {
|
||||
for (HloInstruction* succ : instruction->control_successors()) {
|
||||
unscheduled_pred_count[succ]++;
|
||||
}
|
||||
}
|
||||
@ -275,7 +275,7 @@ class ListScheduler {
|
||||
ready_instructions[inst] = it;
|
||||
};
|
||||
|
||||
for (auto* instruction : computation_.instructions()) {
|
||||
for (auto* instruction : computation_->instructions()) {
|
||||
if (instruction->operands().empty() &&
|
||||
instruction->control_predecessors().empty()) {
|
||||
add_to_ready_queue(instruction);
|
||||
@ -287,7 +287,7 @@ class ListScheduler {
|
||||
// schedule.
|
||||
auto best_it = ready_queue.end();
|
||||
--best_it;
|
||||
const HloInstruction* best = best_it->second.instruction;
|
||||
HloInstruction* best = best_it->second.instruction;
|
||||
VLOG(2) << "Schedule instruction: " << best->ToShortString()
|
||||
<< " Bytes freed: " << best_it->first.first;
|
||||
ready_queue.erase(best_it);
|
||||
@ -348,13 +348,13 @@ class ListScheduler {
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK_EQ(schedule.size(), computation_.instruction_count());
|
||||
CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count());
|
||||
CHECK_EQ(schedule.size(), computation_->instruction_count());
|
||||
CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
|
||||
|
||||
return schedule;
|
||||
}
|
||||
|
||||
const HloComputation& computation_;
|
||||
HloComputation* computation_;
|
||||
const TuplePointsToAnalysis& points_to_analysis_;
|
||||
const LogicalBuffer::SizeFunction& size_function_;
|
||||
// Computations are analyzed in post-order. When scheduling an instruction
|
||||
@ -386,13 +386,13 @@ int64 SumLogicalBufferSizes(
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionSequence> ScheduleComputationHelper(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const MemorySchedulerAlgorithm& algorithm,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
memory_by_computation) {
|
||||
VLOG(2) << "Computation: " << computation.name();
|
||||
VLOG(2) << "Computation: " << computation->name();
|
||||
if (algorithm) {
|
||||
return algorithm(computation, points_to_analysis, size_function,
|
||||
memory_by_computation);
|
||||
@ -404,17 +404,17 @@ StatusOr<HloInstructionSequence> ScheduleComputationHelper(
|
||||
} // namespace
|
||||
|
||||
StatusOr<HloInstructionSequence> DFSMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
memory_by_computation) {
|
||||
// These variables are a hack to prevent overflows.
|
||||
int64 cumulative_total_size = 0;
|
||||
int64 total_hlos = computation.parent()->instruction_count();
|
||||
int64 total_hlos = computation->parent()->instruction_count();
|
||||
absl::flat_hash_map<const HloInstruction*, int64> extra_users;
|
||||
absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
|
||||
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
|
||||
for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
|
||||
if (ListScheduler::IgnoreInstruction(*hlo)) {
|
||||
extra_users[hlo] = 0;
|
||||
total_sizes[hlo] = 0;
|
||||
@ -448,8 +448,8 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
|
||||
total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
|
||||
extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
|
||||
}
|
||||
CHECK_EQ(extra_users.size(), computation.instruction_count());
|
||||
CHECK_EQ(total_sizes.size(), computation.instruction_count());
|
||||
CHECK_EQ(extra_users.size(), computation->instruction_count());
|
||||
CHECK_EQ(total_sizes.size(), computation->instruction_count());
|
||||
|
||||
// Construct a total order based on DFS post-order, visiting operands in
|
||||
// decreasing cumulative extra user order, and next by cumulative size, with a
|
||||
@ -459,7 +459,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
|
||||
sequence.push_back(hlo);
|
||||
return Status::OK();
|
||||
});
|
||||
TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
|
||||
TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
|
||||
&visitor, [&extra_users, &total_sizes](const HloInstruction* a,
|
||||
const HloInstruction* b) {
|
||||
if (extra_users[a] != extra_users[b]) {
|
||||
@ -470,12 +470,12 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
|
||||
}
|
||||
return a->name() < b->name();
|
||||
}));
|
||||
CHECK_EQ(sequence.size(), computation.instruction_count());
|
||||
CHECK_EQ(sequence.size(), computation->instruction_count());
|
||||
return sequence;
|
||||
} // namespace xla
|
||||
|
||||
StatusOr<HloInstructionSequence> ListMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
@ -485,16 +485,16 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
memory_by_computation) {
|
||||
return HloInstructionSequence(computation.MakeInstructionPostOrder());
|
||||
return HloInstructionSequence(computation->MakeInstructionPostOrder());
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
@ -513,7 +513,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
||||
memory_by_computation));
|
||||
TF_ASSIGN_OR_RETURN(const int64 list_memory,
|
||||
HeapSimulator::MinimumMemoryForComputation(
|
||||
computation, list_sequence, points_to_analysis,
|
||||
*computation, list_sequence, points_to_analysis,
|
||||
size_function, &memory_by_computation));
|
||||
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
|
||||
|
||||
@ -522,7 +522,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
||||
size_function, memory_by_computation));
|
||||
TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
|
||||
HeapSimulator::MinimumMemoryForComputation(
|
||||
computation, dfs_sequence, points_to_analysis,
|
||||
*computation, dfs_sequence, points_to_analysis,
|
||||
size_function, &memory_by_computation));
|
||||
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
|
||||
|
||||
@ -532,7 +532,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
||||
memory_by_computation));
|
||||
TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
|
||||
HeapSimulator::MinimumMemoryForComputation(
|
||||
computation, post_order_sequence, points_to_analysis,
|
||||
*computation, post_order_sequence, points_to_analysis,
|
||||
size_function, &memory_by_computation));
|
||||
VLOG(2) << "Min-memory post order sequence: "
|
||||
<< HumanReadableNumBytes(post_order_memory);
|
||||
@ -555,17 +555,17 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
||||
}
|
||||
|
||||
StatusOr<HloSchedule> ScheduleModule(
|
||||
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
|
||||
HloModule* module, const LogicalBuffer::SizeFunction& size_function,
|
||||
const MemorySchedulerAlgorithm& algorithm) {
|
||||
HloSchedule schedule(&module);
|
||||
HloSchedule schedule(module);
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
||||
TuplePointsToAnalysis::Run(&module));
|
||||
TuplePointsToAnalysis::Run(module));
|
||||
absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
|
||||
for (const auto* computation : module.MakeComputationPostOrder()) {
|
||||
for (auto* computation : module->MakeComputationPostOrder()) {
|
||||
if (!computation->IsFusionComputation()) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
|
||||
ScheduleComputationHelper(
|
||||
*computation, *points_to_analysis, size_function,
|
||||
computation, *points_to_analysis, size_function,
|
||||
algorithm, memory_by_computation));
|
||||
memory_by_computation[computation] =
|
||||
HeapSimulator::MinimumMemoryForComputation(
|
||||
@ -583,11 +583,11 @@ StatusOr<HloSchedule> ScheduleModule(
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionSequence> ScheduleComputation(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const LogicalBuffer::SizeFunction& size_function) {
|
||||
CHECK(!computation.IsFusionComputation());
|
||||
CHECK(!computation->IsFusionComputation());
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
||||
TuplePointsToAnalysis::Run(computation.parent()));
|
||||
TuplePointsToAnalysis::Run(computation->parent()));
|
||||
absl::flat_hash_map<const HloComputation*, int64> empty_map;
|
||||
return ScheduleComputationHelper(computation, *points_to_analysis,
|
||||
size_function, nullptr, empty_map);
|
||||
@ -600,7 +600,7 @@ HloMemoryScheduler::HloMemoryScheduler(
|
||||
|
||||
StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(HloSchedule schedule,
|
||||
ScheduleModule(*module, size_function_, algorithm_));
|
||||
ScheduleModule(module, size_function_, algorithm_));
|
||||
TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
|
||||
return true;
|
||||
}
|
||||
|
@ -36,14 +36,14 @@ namespace xla {
|
||||
// that describes buffer aliasing, together with a target-specific size function
|
||||
// that maps a tensor's logical size to its padded size.
|
||||
typedef std::function<StatusOr<HloInstructionSequence>(
|
||||
const HloComputation&, const TuplePointsToAnalysis&,
|
||||
HloComputation*, const TuplePointsToAnalysis&,
|
||||
const LogicalBuffer::SizeFunction&,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&)>
|
||||
MemorySchedulerAlgorithm;
|
||||
|
||||
// List scheduler
|
||||
StatusOr<HloInstructionSequence> ListMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
@ -51,7 +51,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
|
||||
|
||||
// DFS-order scheduler
|
||||
StatusOr<HloInstructionSequence> DFSMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
@ -59,7 +59,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
|
||||
|
||||
// Naive Post Order scheduler
|
||||
StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
@ -69,7 +69,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
|
||||
// and the DFS scheduler, and chooses whichever returns a lower min-memory,
|
||||
// not accounting for fragmentation.
|
||||
StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const LogicalBuffer::SizeFunction& size_function,
|
||||
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||
@ -79,13 +79,13 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
||||
// the computation. size_function is the function returning the number of bytes
|
||||
// required for a LogicalBuffer.
|
||||
StatusOr<HloSchedule> ScheduleModule(
|
||||
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
|
||||
HloModule* module, const LogicalBuffer::SizeFunction& size_function,
|
||||
const MemorySchedulerAlgorithm& algorithm = {});
|
||||
|
||||
// Computes the schedule for a single computation.
|
||||
// Currently only used by the GPU backend.
|
||||
StatusOr<HloInstructionSequence> ScheduleComputation(
|
||||
const HloComputation& computation,
|
||||
HloComputation* computation,
|
||||
const LogicalBuffer::SizeFunction& size_function);
|
||||
|
||||
// A pass which schedules the HLO instructions in a module. The HloModule's
|
||||
|
@ -78,7 +78,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
|
||||
TF_ASSERT_OK(module->schedule().Verify());
|
||||
|
||||
// Verify that all instructions are in the sequence.
|
||||
const std::vector<const HloInstruction*>& sequence =
|
||||
const std::vector<HloInstruction*>& sequence =
|
||||
module->schedule().sequence(module->entry_computation()).instructions();
|
||||
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
|
||||
|
||||
@ -124,9 +124,9 @@ ENTRY root {
|
||||
};
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, size_fn, ListMemoryScheduler));
|
||||
ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
|
||||
// Verify that all instructions are in the sequence.
|
||||
const std::vector<const HloInstruction*>& sequence =
|
||||
const std::vector<HloInstruction*>& sequence =
|
||||
schedule.sequence(module->entry_computation()).instructions();
|
||||
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
|
||||
|
||||
@ -175,12 +175,13 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
|
||||
ScheduleModule(*module,
|
||||
[](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(
|
||||
buffer.shape(), TUPLE_SIZE);
|
||||
},
|
||||
ListMemoryScheduler));
|
||||
ScheduleModule(
|
||||
module.get(),
|
||||
[](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape(),
|
||||
TUPLE_SIZE);
|
||||
},
|
||||
ListMemoryScheduler));
|
||||
|
||||
// Verify that all instructions are in the sequence.
|
||||
EXPECT_EQ(module->entry_computation()->instruction_count(),
|
||||
@ -225,12 +226,12 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
|
||||
{tuple, mul, add}, HloInstruction::FusionKind::kLoop);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
|
||||
ScheduleModule(*module,
|
||||
[](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(
|
||||
buffer.shape(), 2);
|
||||
},
|
||||
ListMemoryScheduler));
|
||||
ScheduleModule(
|
||||
module.get(),
|
||||
[](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
|
||||
},
|
||||
ListMemoryScheduler));
|
||||
|
||||
// Verify that all instructions are in the sequence.
|
||||
EXPECT_EQ(module->entry_computation()->instruction_count(),
|
||||
@ -284,7 +285,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
|
||||
};
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, size_fn, ListMemoryScheduler));
|
||||
ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
|
||||
// Verify that all instructions are in the sequence.
|
||||
auto entry_computation = module->entry_computation();
|
||||
EXPECT_EQ(module->entry_computation()->instruction_count(),
|
||||
|
@ -104,11 +104,7 @@ class HloModule {
|
||||
HloCloneContext* context = nullptr);
|
||||
|
||||
// Return a pointer to the entry computation of the module.
|
||||
const HloComputation* entry_computation() const {
|
||||
CHECK_NE(nullptr, entry_computation_);
|
||||
return entry_computation_;
|
||||
}
|
||||
HloComputation* entry_computation() {
|
||||
HloComputation* entry_computation() const {
|
||||
CHECK_NE(nullptr, entry_computation_);
|
||||
return entry_computation_;
|
||||
}
|
||||
|
@ -356,8 +356,7 @@ void SequentialHloOrdering::Initialize() {
|
||||
// Create a map from instruction to its order position.
|
||||
TF_DCHECK_OK(schedule_.Verify());
|
||||
for (const auto& computation_sequence : schedule_.sequences()) {
|
||||
const std::vector<const HloInstruction*>& order =
|
||||
computation_sequence.second.instructions();
|
||||
const auto& order = computation_sequence.second.instructions();
|
||||
for (int i = 0; i < order.size(); ++i) {
|
||||
InsertOrDie(&order_position_, order[i], i);
|
||||
}
|
||||
|
@ -47,11 +47,11 @@ const double kF16max = 65504;
|
||||
|
||||
// Creates and returns a schedule created using the order of the instructions in
|
||||
// the HloComputation::instructions() vectors in the module.
|
||||
HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
|
||||
HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
|
||||
HloSchedule schedule(module);
|
||||
for (const HloComputation* computation : module->computations()) {
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
if (!computation->IsFusionComputation()) {
|
||||
for (const HloInstruction* instruction : computation->instructions()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
schedule.GetOrCreateSequence(computation).push_back(instruction);
|
||||
}
|
||||
}
|
||||
|
@ -130,10 +130,10 @@ using ItemList = absl::InlinedVector<Item*, 3>;
|
||||
// before arbitrary elements.
|
||||
class InstructionList {
|
||||
public:
|
||||
explicit InstructionList(const std::vector<const HloInstruction*>& order) {
|
||||
explicit InstructionList(const HloInstructionSequence& order) {
|
||||
int64 position = 0;
|
||||
Item* last = nullptr;
|
||||
for (const HloInstruction* inst : order) {
|
||||
for (HloInstruction* inst : order.instructions()) {
|
||||
// Add a new item to the linked list.
|
||||
Item* item = new Item;
|
||||
item->next = nullptr;
|
||||
@ -151,7 +151,7 @@ class InstructionList {
|
||||
// to be monotonically increasing through the list, and so is still useful
|
||||
// for quickly(-ish) determining the order of arbitrary instructions in
|
||||
// the list.
|
||||
item->instruction = const_cast<HloInstruction*>(inst);
|
||||
item->instruction = inst;
|
||||
item->position = position;
|
||||
position++;
|
||||
|
||||
@ -927,7 +927,7 @@ Item* PickRematerializationCandidate(
|
||||
|
||||
StatusOr<int64> HloRematerialization::ComputePeakMemory(
|
||||
const HloComputation* computation,
|
||||
const std::vector<const HloInstruction*>& order) const {
|
||||
const HloInstructionSequence& order) const {
|
||||
InstructionList instruction_list(order);
|
||||
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
|
||||
instruction_list);
|
||||
@ -971,8 +971,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
<< HumanReadableNumBytes(computation_peak_memory_.at(computation));
|
||||
CHECK(!ContainsKey(rematerialized_computations_, computation));
|
||||
|
||||
InstructionList instruction_list(
|
||||
schedule->sequence(computation).instructions());
|
||||
InstructionList instruction_list(schedule->sequence(computation));
|
||||
MemoryUsageTracker memory_tracker(computation, size_function_,
|
||||
*points_to_analysis_, instruction_list);
|
||||
bool changed = false;
|
||||
@ -1184,7 +1183,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
sequence.clear();
|
||||
for (auto* item = instruction_list.first(); item != nullptr;
|
||||
item = instruction_list.next(item)) {
|
||||
const HloInstruction* instruction = item->instruction;
|
||||
HloInstruction* instruction = item->instruction;
|
||||
sequence.push_back(instruction);
|
||||
}
|
||||
rematerialized_computations_.insert(computation);
|
||||
@ -1235,10 +1234,8 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
|
||||
if (node.context() == CallContext::kSequential) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
computation_peak_memory_[node.computation()],
|
||||
ComputePeakMemory(node.computation(),
|
||||
module->schedule()
|
||||
.sequence(node.computation())
|
||||
.instructions()));
|
||||
ComputePeakMemory(node.computation(), module->schedule().sequence(
|
||||
node.computation())));
|
||||
}
|
||||
return Status::OK();
|
||||
},
|
||||
|
@ -87,9 +87,8 @@ class HloRematerialization : public HloModulePass {
|
||||
// peak memory is the maximum total size of all live HLO instruction values at
|
||||
// any program point. 'order' is the order in which the HLO instructions will
|
||||
// be emitted which is used to determine lifespans of HLO values.
|
||||
StatusOr<int64> ComputePeakMemory(
|
||||
const HloComputation* computation,
|
||||
const std::vector<const HloInstruction*>& order) const;
|
||||
StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
|
||||
const HloInstructionSequence& order) const;
|
||||
|
||||
// Returns the peak memory usage of the called computations for the given
|
||||
// instruction. Zero is returned if the instruction calls no computations.
|
||||
|
@ -46,8 +46,8 @@ namespace xla {
|
||||
<< "No computation exists in HLO module with id " << computation_id;
|
||||
const HloComputation* computation = comp_it->second;
|
||||
|
||||
absl::flat_hash_map<int64, const HloInstruction*> id_to_instruction;
|
||||
for (const HloInstruction* instruction : computation->instructions()) {
|
||||
absl::flat_hash_map<int64, HloInstruction*> id_to_instruction;
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
id_to_instruction[instruction->unique_id()] = instruction;
|
||||
}
|
||||
|
||||
@ -81,9 +81,8 @@ StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
|
||||
return std::move(proto);
|
||||
}
|
||||
|
||||
void HloSchedule::set_sequence(
|
||||
const HloComputation* computation,
|
||||
absl::Span<const HloInstruction* const> sequence) {
|
||||
void HloSchedule::set_sequence(const HloComputation* computation,
|
||||
absl::Span<HloInstruction* const> sequence) {
|
||||
set_sequence(computation, HloInstructionSequence(sequence));
|
||||
}
|
||||
|
||||
@ -114,8 +113,8 @@ Status HloSchedule::UpdateComputationSchedule(
|
||||
const HloComputation* computation) {
|
||||
// Map from unique ID to HloInstruction pointer for instructions in the
|
||||
// computation.
|
||||
absl::flat_hash_map<int, const HloInstruction*> id_to_instruction;
|
||||
for (const HloInstruction* instruction : computation->instructions()) {
|
||||
absl::flat_hash_map<int, HloInstruction*> id_to_instruction;
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
|
||||
}
|
||||
|
||||
@ -128,7 +127,7 @@ Status HloSchedule::UpdateComputationSchedule(
|
||||
// Map from HloInstruction X to newly added instructions (instruction is in
|
||||
// computation, but not in schedule) which use X. If an instruction is not in
|
||||
// the map, then it has no users which are newly added instructions.
|
||||
absl::flat_hash_map<const HloInstruction*, std::vector<const HloInstruction*>>
|
||||
absl::flat_hash_map<const HloInstruction*, std::vector<HloInstruction*>>
|
||||
new_instruction_uses;
|
||||
|
||||
// For each newly added instruction, this is the count of the instruction's
|
||||
@ -138,9 +137,9 @@ Status HloSchedule::UpdateComputationSchedule(
|
||||
|
||||
// Create a worklist of newly added instructions which are ready to be added
|
||||
// to the schedule. Initialize worklist with those that have zero operands.
|
||||
std::queue<const HloInstruction*> worklist;
|
||||
std::queue<HloInstruction*> worklist;
|
||||
|
||||
for (const HloInstruction* instruction : computation->instructions()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (ids_in_schedule.count(instruction->unique_id()) == 0) {
|
||||
// This is a newly added instruction which is not in the schedule.
|
||||
if (instruction->operands().empty()) {
|
||||
@ -161,17 +160,17 @@ Status HloSchedule::UpdateComputationSchedule(
|
||||
// Lambda which schedules all instructions on the worklist.
|
||||
auto schedule_worklist = [&]() {
|
||||
while (!worklist.empty()) {
|
||||
const HloInstruction* instruction = worklist.front();
|
||||
HloInstruction* instruction = worklist.front();
|
||||
worklist.pop();
|
||||
new_sequence.push_back(instruction);
|
||||
std::vector<const HloInstruction*>* new_users =
|
||||
std::vector<HloInstruction*>* new_users =
|
||||
tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
|
||||
if (new_users != nullptr) {
|
||||
// This just-scheduled instruction has users which are newly added to
|
||||
// the module. Update the number of unscheduled operands and push the
|
||||
// newly added instruction to the worklist if it is ready to
|
||||
// schedule.
|
||||
for (const HloInstruction* new_user : *new_users) {
|
||||
for (HloInstruction* new_user : *new_users) {
|
||||
unscheduled_operand_count.at(new_user)--;
|
||||
CHECK_GE(unscheduled_operand_count.at(new_user), 0);
|
||||
if (unscheduled_operand_count.at(new_user) == 0) {
|
||||
|
@ -35,14 +35,14 @@ class HloInstructionSequence {
|
||||
public:
|
||||
HloInstructionSequence() = default;
|
||||
explicit HloInstructionSequence(
|
||||
absl::Span<const HloInstruction* const> instructions) {
|
||||
for (const HloInstruction* instruction : instructions) {
|
||||
absl::Span<HloInstruction* const> instructions) {
|
||||
for (HloInstruction* instruction : instructions) {
|
||||
push_back(instruction);
|
||||
}
|
||||
}
|
||||
|
||||
// Adds the instruction to the end of the sequence.
|
||||
void push_back(const HloInstruction* instruction) {
|
||||
void push_back(HloInstruction* instruction) {
|
||||
instruction_sequence_.push_back(instruction);
|
||||
id_sequence_.push_back(instruction->unique_id());
|
||||
}
|
||||
@ -56,7 +56,7 @@ class HloInstructionSequence {
|
||||
int64 size() const { return instruction_sequence_.size(); }
|
||||
|
||||
// Returns the sequence of HLO instructions.
|
||||
const std::vector<const HloInstruction*>& instructions() const {
|
||||
const std::vector<HloInstruction*>& instructions() const {
|
||||
return instruction_sequence_;
|
||||
}
|
||||
|
||||
@ -65,7 +65,7 @@ class HloInstructionSequence {
|
||||
|
||||
private:
|
||||
// The sequence as HloInstructions.
|
||||
std::vector<const HloInstruction*> instruction_sequence_;
|
||||
std::vector<HloInstruction*> instruction_sequence_;
|
||||
|
||||
// The sequence of HLO instructions, represented by their unique IDs. The
|
||||
// sequence is stored as both HloInstructions and unique IDs because the
|
||||
@ -98,7 +98,7 @@ class HloSchedule {
|
||||
|
||||
// Sets the sequence for the given computation to the given sequence.
|
||||
void set_sequence(const HloComputation* computation,
|
||||
absl::Span<const HloInstruction* const> sequence);
|
||||
absl::Span<HloInstruction* const> sequence);
|
||||
void set_sequence(const HloComputation* computation,
|
||||
HloInstructionSequence sequence);
|
||||
|
||||
|
@ -56,10 +56,10 @@ ENTRY main {
|
||||
ParseHloString(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, [](const BufferValue& buffer) {
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape());
|
||||
}));
|
||||
const std::vector<const HloInstruction*>& entry_schedule =
|
||||
const auto& entry_schedule =
|
||||
schedule.sequence(module->entry_computation()).instructions();
|
||||
|
||||
EXPECT_EQ(entry_schedule.size(), 6);
|
||||
@ -90,7 +90,7 @@ ENTRY main {
|
||||
ParseHloString(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, [](const BufferValue& buffer) {
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape());
|
||||
}));
|
||||
|
||||
@ -139,7 +139,7 @@ ENTRY main {
|
||||
ParseHloString(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, [](const BufferValue& buffer) {
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape());
|
||||
}));
|
||||
|
||||
@ -183,7 +183,7 @@ ENTRY main {
|
||||
ParseHloString(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, [](const BufferValue& buffer) {
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape());
|
||||
}));
|
||||
|
||||
@ -244,7 +244,7 @@ ENTRY %WhileLoop () -> s32[] {
|
||||
ParseHloString(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, [](const BufferValue& buffer) {
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape(),
|
||||
/*pointer_size=*/sizeof(void*));
|
||||
}));
|
||||
@ -313,7 +313,7 @@ ENTRY %WhileLoop () -> s32[] {
|
||||
ParseHloString(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(*module, [](const BufferValue& buffer) {
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape(),
|
||||
/*pointer_size=*/sizeof(void*));
|
||||
}));
|
||||
|
Loading…
Reference in New Issue
Block a user