[TF:XLA] Remove a const_cast from HLO rematerialization.
This required making HloSchedule and related classes hold sequences of HloInstruction* (instead of const HloInstruction*). Also, HloInstructionSequence is now used in a few more places (hlo_rematerialization.cc). No functional change.

PiperOrigin-RevId: 221868639
commit 2121365e74
parent 8930d5aff5
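To make the const-ness change concrete, here is a small, self-contained sketch (not taken from the patch; the types OldSequence and NewSequence are invented for the illustration). It shows why storing const HloInstruction* forced callers such as rematerialization to const_cast, and why storing mutable HloInstruction* removes the cast:

#include <vector>

// Stand-in type, only so the illustration compiles on its own.
struct HloInstruction {};

// Before: the sequence hands out const pointers, so a pass that needs to
// mutate an instruction has to cast the const away.
struct OldSequence {
  std::vector<const HloInstruction*> instructions;
};

// After: the sequence stores mutable pointers and the const_cast disappears.
struct NewSequence {
  std::vector<HloInstruction*> instructions;
};

HloInstruction* FirstMutableBefore(const OldSequence& seq) {
  // This is the kind of cast the commit removes.
  return const_cast<HloInstruction*>(seq.instructions[0]);
}

HloInstruction* FirstMutableAfter(const NewSequence& seq) {
  return seq.instructions[0];  // No cast needed.
}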
@@ -137,8 +137,7 @@ class BufferAssignmentTest : public HloTestBase {
   }
 
   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
-      HloModule* module,
-      absl::Span<const HloInstruction* const> instruction_sequence,
+      HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
       int64 alignment = 1) {
     HloSchedule schedule(module);
     schedule.set_sequence(module->entry_computation(), instruction_sequence);
@@ -1853,7 +1852,7 @@ class WhileBufferAssignmentTest : public HloTestBase {
   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
                                                         int64 alignment = 1) {
     HloSchedule schedule =
-        ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
+        ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
     return BufferAssigner::Run(
                module, absl::make_unique<SequentialHloOrdering>(schedule),
                ByteSizeOf,
@@ -2162,7 +2161,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
   // nodes are traversed during BufferAssignment.
   TF_ASSERT_OK_AND_ASSIGN(
       HloSchedule schedule,
-      ScheduleModule(*module, [](const BufferValue& buffer) {
+      ScheduleModule(module.get(), [](const BufferValue& buffer) {
         return ShapeUtil::ByteSizeOf(buffer.shape(),
                                      /*pointer_size=*/sizeof(void*));
       }));
@@ -2391,15 +2390,16 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
   RunCopyInsertion(module.get());
 
   HloSchedule schedule =
-      ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
+      ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();
 
   // To trigger b/38494731, we want a specific Hlo schedule for the
   // root computation, so we overwrite that entry with a manually
   // crafted sequence.
-  schedule.set_sequence(module->entry_computation(),
-                        {input1, weights1, one, output1, while1->operand(0),
-                         while1, input0, weights0, zero, output0,
-                         while0->operand(0), while0, gte0, gte1, root_add});
+  schedule.set_sequence(
+      module->entry_computation(),
+      {input1, weights1, one, output1, while1->mutable_operand(0), while1,
+       input0, weights0, zero, output0, while0->mutable_operand(0), while0,
+       gte0, gte1, root_add});
 
   // If this ASSERT fails, we constructed a bogus sequence above and this test
   // itself is buggy.
@@ -587,9 +587,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
   // Select an order for emitting the HLO instructions for each
   // computation. Using this sequence enables tighter buffer liveness analysis
   // and reduced memory usage (as compared to using DependencyHloOrdering).
-  TF_ASSIGN_OR_RETURN(
-      HloSchedule schedule,
-      ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler));
+  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
+                                     DFSMemoryScheduler));
 
   // Run buffer allocation on the HLO graph.
   TF_ASSIGN_OR_RETURN(
@@ -779,7 +779,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
     XLA_VLOG_LINES(2, module->ToString());
 
     TF_ASSIGN_OR_RETURN(HloSchedule schedule,
-                        ScheduleModule(*module, BufferSizeBytesFunction()));
+                        ScheduleModule(module, BufferSizeBytesFunction()));
 
     // Run buffer analysis on the HLO graph. This analysis figures out which
    // temporary buffers are required to run the computation.
@@ -111,7 +111,7 @@ IrEmitter::IrEmitter(
 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
     HloComputation* computation, const string& function_name_prefix,
     bool is_top_level_computation,
-    const std::vector<const HloInstruction*>* instruction_order) {
+    const std::vector<HloInstruction*>* instruction_order) {
   string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
   VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
           << "]; ordered? " << (instruction_order != nullptr);
@@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   StatusOr<llvm::Function*> EmitComputation(
       HloComputation* computation, const string& function_name_prefix,
       bool is_top_level_computation,
-      const std::vector<const HloInstruction*>* instruction_order);
+      const std::vector<HloInstruction*>* instruction_order);
 
   llvm::IRBuilder<>* b() { return &b_; }
 
@@ -37,7 +37,7 @@ class GpuHloOrdering : public PredecessorHloOrdering {
  public:
   GpuHloOrdering(const HloModule* module,
                  const StreamAssignment& stream_assignment,
-                 const std::vector<const HloInstruction*>& thunk_launch_order);
+                 const std::vector<HloInstruction*>& thunk_launch_order);
   ~GpuHloOrdering() override = default;
 
   // Only the entry computation can possibly be sequentially ordered, and only
@@ -56,7 +56,7 @@ class GpuHloOrdering : public PredecessorHloOrdering {
 
 GpuHloOrdering::GpuHloOrdering(
     const HloModule* module, const StreamAssignment& stream_assignment,
-    const std::vector<const HloInstruction*>& thunk_launch_order)
+    const std::vector<HloInstruction*>& thunk_launch_order)
     : PredecessorHloOrdering(module) {
   // The entry computation has a total order when there's only one stream.
   if (stream_assignment.StreamCount() == 1) {
@@ -150,7 +150,7 @@ GpuHloOrdering::GpuHloOrdering(
 // However, if the total order is A,B,D,C,E, then C and E can run
 // concurrently.
 void BFSLaunchOrder(const HloComputation* computation,
-                    std::vector<const HloInstruction*>* launch_order) {
+                    std::vector<HloInstruction*>* launch_order) {
   // This topological sort uses two data structures:
   // 1. `incoming_edge_count` which keeps track of the number of incoming
   //    edges to each HLO;
@@ -158,9 +158,9 @@ void BFSLaunchOrder(const HloComputation* computation,
   //
   // The sorting algorithm repeatedly pops the top from the queue and deletes
   // that HLO from the graph, making more HLOs incoming-edge free.
-  std::deque<const HloInstruction*> queue;
+  std::deque<HloInstruction*> queue;
   std::unordered_map<const HloInstruction*, int64> incoming_edge_count;
-  for (const auto& hlo : computation->instructions()) {
+  for (auto* hlo : computation->instructions()) {
     if (hlo->operand_count() == 0) {
       queue.push_back(hlo);
     } else {
@@ -172,10 +172,10 @@ void BFSLaunchOrder(const HloComputation* computation,
   }
 
   while (!queue.empty()) {
-    const HloInstruction* x = queue.front();
+    HloInstruction* x = queue.front();
     queue.pop_front();
     launch_order->push_back(x);
-    for (const HloInstruction* y : x->users()) {
+    for (HloInstruction* y : x->users()) {
       --incoming_edge_count[y];
       if (incoming_edge_count[y] == 0) {
         queue.push_back(y);
@@ -195,14 +195,14 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
   std::unique_ptr<GpuHloSchedule> schedule(new GpuHloSchedule);
 
   // Initialize thunk_launch_order_, the total order of thunk launches.
-  const HloComputation* entry_computation = module.entry_computation();
+  HloComputation* entry_computation = module.entry_computation();
   if (stream_assignment.StreamCount() == 1) {
     // All kernels are launched on a single stream, so there's no loss of
     // concurrency by optimizing for minimal memory usage.
     TF_ASSIGN_OR_RETURN(
         HloInstructionSequence sequence,
         ScheduleComputation(
-            *entry_computation, [pointer_size](const BufferValue& buffer) {
+            entry_computation, [pointer_size](const BufferValue& buffer) {
               return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
             }));
     schedule->thunk_launch_order_ = sequence.instructions();
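As an aside, the queue-based topological sort that BFSLaunchOrder's comments describe (count incoming edges, seed a queue with edge-free nodes, repeatedly pop one and release its users) can be sketched generically as follows. The Node type and the function name are illustrative only and are not part of XLA:

#include <deque>
#include <unordered_map>
#include <vector>

struct Node {
  std::vector<Node*> users;  // Outgoing edges: nodes that consume this one.
};

std::vector<Node*> TopologicalLaunchOrder(const std::vector<Node*>& nodes) {
  // 1. Count incoming edges for every node from the outgoing-edge lists.
  std::unordered_map<Node*, int> incoming_edge_count;
  for (Node* node : nodes) incoming_edge_count[node] = 0;
  for (Node* node : nodes) {
    for (Node* user : node->users) ++incoming_edge_count[user];
  }
  // 2. Seed the queue with nodes that have no incoming edges.
  std::deque<Node*> queue;
  for (Node* node : nodes) {
    if (incoming_edge_count[node] == 0) queue.push_back(node);
  }
  // 3. Repeatedly pop a ready node and release its users.
  std::vector<Node*> launch_order;
  while (!queue.empty()) {
    Node* x = queue.front();
    queue.pop_front();
    launch_order.push_back(x);
    for (Node* y : x->users) {
      if (--incoming_edge_count[y] == 0) queue.push_back(y);
    }
  }
  return launch_order;
}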
@@ -46,7 +46,7 @@ class GpuHloSchedule {
 
   // Returns the total order of thunk launches, represented in terms of HLO
   // instructions.
-  const std::vector<const HloInstruction*>& ThunkLaunchOrder() const {
+  const std::vector<HloInstruction*>& ThunkLaunchOrder() const {
     return thunk_launch_order_;
   }
 
@@ -60,7 +60,7 @@ class GpuHloSchedule {
  private:
   GpuHloSchedule();
 
-  std::vector<const HloInstruction*> thunk_launch_order_;
+  std::vector<HloInstruction*> thunk_launch_order_;
   std::unique_ptr<HloOrdering> hlo_ordering_;
 };
 
@@ -33,7 +33,7 @@ namespace gpu {
 
 class GpuHloScheduleTest : public HloTestBase {
  protected:
-  using HloVec = std::vector<const HloInstruction*>;
+  using HloVec = std::vector<HloInstruction*>;
 
   // Pre-canned shapes.
   Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
@@ -45,7 +45,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands(
 ThunkSchedule::ThunkSchedule(
     std::unique_ptr<ThunkSequence> thunks,
     std::unique_ptr<StreamAssignment> stream_assignment,
-    const std::vector<const HloInstruction*>& hlo_total_order)
+    const std::vector<HloInstruction*>& hlo_total_order)
     : thunks_(std::move(thunks)),
       stream_assignment_(std::move(stream_assignment)) {
   std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk;
@@ -53,7 +53,7 @@ ThunkSchedule::ThunkSchedule(
     InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
   }
 
-  for (const HloInstruction* hlo : hlo_total_order) {
+  for (HloInstruction* hlo : hlo_total_order) {
     if (hlo_to_thunk.count(hlo)) {
       thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo));
     }
@@ -46,7 +46,7 @@ class ThunkSchedule {
  public:
   ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,
                 std::unique_ptr<StreamAssignment> stream_assignment,
-                const std::vector<const HloInstruction*>& hlo_total_order);
+                const std::vector<HloInstruction*>& hlo_total_order);
 
   // Returns the total order of executing all the thunks.
   const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }
@@ -258,7 +258,7 @@ class HeapSimulatorTracker {
   // Constructor for testing a single entry computation.
   HeapSimulatorTracker(
       const string& name, std::unique_ptr<HloComputation> computation,
-      const std::vector<const HloInstruction*>& instruction_sequence) {
+      const std::vector<HloInstruction*>& instruction_sequence) {
     HloModuleConfig config;
     module_ = absl::make_unique<HloModule>(name, config);
     module_->AddEntryComputation(std::move(computation));
@@ -286,7 +286,7 @@ class HeapSimulatorTracker {
   // Similar to the single entry computation constructor above, but runs the
   // simulation over the entire module.
   void RunWholeModule(
-      const std::vector<const HloInstruction*>& full_module_sequence) {
+      const std::vector<HloInstruction*>& full_module_sequence) {
     points_to_analysis_ =
         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
 
@@ -294,7 +294,7 @@ class HeapSimulatorTracker {
     HloSchedule schedule(module_.get());
     absl::flat_hash_map<const HloInstruction*, int> reverse_position;
     for (int i = 0; i < full_module_sequence.size(); ++i) {
-      const HloInstruction* instruction = full_module_sequence[i];
+      HloInstruction* instruction = full_module_sequence[i];
       schedule.GetOrCreateSequence(instruction->parent())
           .push_back(instruction);
       reverse_position[instruction] = full_module_sequence.size() - i;
@@ -795,7 +795,7 @@ Status HloComputation::AcceptWithOperandOrder(
 template <typename HloInstructionPtr>
 Status HloComputation::AcceptOrdered(
     DfsHloVisitorBase<HloInstructionPtr>* visitor,
-    const std::vector<const HloInstruction*>& order) const {
+    const std::vector<HloInstruction*>& order) const {
   VLOG(3) << "Accepting visitor with order.";
   for (HloInstruction* root : CollectUnreachableRoots()) {
     TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
@@ -825,9 +825,9 @@ Status HloComputation::AcceptOrdered(
 
 // Explicit instantiations.
 template Status HloComputation::AcceptOrdered(
-    DfsHloVisitor*, const std::vector<const HloInstruction*>&) const;
+    DfsHloVisitor*, const std::vector<HloInstruction*>&) const;
 template Status HloComputation::AcceptOrdered(
-    ConstDfsHloVisitor*, const std::vector<const HloInstruction*>&) const;
+    ConstDfsHloVisitor*, const std::vector<HloInstruction*>&) const;
 
 Status HloComputation::Accept(
     const std::function<Status(HloInstruction*)>& visitor_func) {
@@ -301,7 +301,7 @@ class HloComputation {
   // be a topological sort of all instructions in the computation.
   template <typename HloInstructionPtr>
   Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
-                       const std::vector<const HloInstruction*>& order) const;
+                       const std::vector<HloInstruction*>& order) const;
 
   // Same as Accept() above, but the visitor is given as a function.
   Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
@@ -73,7 +73,7 @@ class ListScheduler {
   // Construct and return a memory-minimizing sequence of HLO instructions
   // containing the given HLO computation.
   static StatusOr<HloInstructionSequence> Run(
-      const HloComputation& computation,
+      HloComputation* computation,
       const TuplePointsToAnalysis& points_to_analysis,
       const LogicalBuffer::SizeFunction& size_function,
       const absl::flat_hash_map<const HloComputation*, int64>&
@@ -98,7 +98,7 @@ class ListScheduler {
   // comparison operators.
   using Priority = std::pair<int64, int64>;
 
-  ListScheduler(const HloComputation& computation,
+  ListScheduler(HloComputation* computation,
                 const TuplePointsToAnalysis& points_to_analysis,
                 const LogicalBuffer::SizeFunction& size_function,
                 const absl::flat_hash_map<const HloComputation*, int64>&
@@ -111,7 +111,7 @@ class ListScheduler {
     // instruction. An HLO instruction "uses" a LogicalBuffer if the
     // LogicalBuffer is in an operand of the instruction as indicated by
     // points-to analysis.
-    for (auto* instruction : computation.instructions()) {
+    for (auto* instruction : computation->instructions()) {
       absl::flat_hash_set<const LogicalBuffer*> instr_uses;
       for (auto* operand : instruction->operands()) {
         points_to_analysis.GetPointsToSet(operand).ForEachElement(
@@ -126,13 +126,13 @@ class ListScheduler {
 
     // Create map containing the number of unscheduled uses (hlo instructions)
     // of each logical buffer.
-    for (auto* instruction : computation.instructions()) {
+    for (auto* instruction : computation->instructions()) {
       for (auto* buffer :
            points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
         unscheduled_use_count_[buffer] = 0;
       }
     }
-    for (auto* instruction : computation.instructions()) {
+    for (auto* instruction : computation->instructions()) {
       for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
         ++unscheduled_use_count_[buffer];
       }
@@ -141,7 +141,7 @@ class ListScheduler {
     // Buffers live out of the computation have an implicit use at the end of
    // the computation.
     for (const LogicalBuffer* live_out_buffer :
-         points_to_analysis.GetPointsToSet(computation.root_instruction())
+         points_to_analysis.GetPointsToSet(computation->root_instruction())
             .CreateFlattenedSet()) {
       ++unscheduled_use_count_[live_out_buffer];
     }
@@ -157,7 +157,7 @@ class ListScheduler {
   // HloInstruction, plus some cached metadata, saved for the purposes of making
   // BytesFreedIfScheduled fast.
   struct ReadyListEntry {
-    const HloInstruction* instruction;
+    HloInstruction* instruction;
 
     // The total size of all buffers defined by this instruction.
     int64 bytes_defined;
@@ -171,7 +171,7 @@ class ListScheduler {
   };
 
   // Creates a ReadyListEntry for the given instruction.
-  ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) {
+  ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
     ReadyListEntry entry;
     entry.instruction = instruction;
 
@@ -250,13 +250,13 @@ class ListScheduler {
     // Populate the ready list with instructions which have no operands or
    // control predecessors.
     absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
-    for (auto* instruction : computation_.instructions()) {
+    for (auto* instruction : computation_->instructions()) {
       // TODO(b/34466113): Replace this and above with successors() or
       // predecessors() when these methods are added to HloInstruction.
-      for (const HloInstruction* user : instruction->users()) {
+      for (HloInstruction* user : instruction->users()) {
         unscheduled_pred_count[user]++;
       }
-      for (const HloInstruction* succ : instruction->control_successors()) {
+      for (HloInstruction* succ : instruction->control_successors()) {
         unscheduled_pred_count[succ]++;
       }
     }
@@ -275,7 +275,7 @@ class ListScheduler {
       ready_instructions[inst] = it;
     };
 
-    for (auto* instruction : computation_.instructions()) {
+    for (auto* instruction : computation_->instructions()) {
       if (instruction->operands().empty() &&
           instruction->control_predecessors().empty()) {
         add_to_ready_queue(instruction);
@@ -287,7 +287,7 @@ class ListScheduler {
      // schedule.
      auto best_it = ready_queue.end();
      --best_it;
-      const HloInstruction* best = best_it->second.instruction;
+      HloInstruction* best = best_it->second.instruction;
      VLOG(2) << "Schedule instruction: " << best->ToShortString()
              << " Bytes freed: " << best_it->first.first;
      ready_queue.erase(best_it);
@@ -348,13 +348,13 @@ class ListScheduler {
         }
       }
     }
-    CHECK_EQ(schedule.size(), computation_.instruction_count());
-    CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count());
+    CHECK_EQ(schedule.size(), computation_->instruction_count());
+    CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
 
     return schedule;
   }
 
-  const HloComputation& computation_;
+  HloComputation* computation_;
   const TuplePointsToAnalysis& points_to_analysis_;
   const LogicalBuffer::SizeFunction& size_function_;
   // Computations are analyzed in post-order. When scheduling an instruction
@@ -386,13 +386,13 @@ int64 SumLogicalBufferSizes(
 }
 
 StatusOr<HloInstructionSequence> ScheduleComputationHelper(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const MemorySchedulerAlgorithm& algorithm,
     const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
-  VLOG(2) << "Computation: " << computation.name();
+  VLOG(2) << "Computation: " << computation->name();
   if (algorithm) {
     return algorithm(computation, points_to_analysis, size_function,
                      memory_by_computation);
@@ -404,17 +404,17 @@ StatusOr<HloInstructionSequence> ScheduleComputationHelper(
 }  // namespace
 
 StatusOr<HloInstructionSequence> DFSMemoryScheduler(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
   // These variables are a hack to prevent overflows.
   int64 cumulative_total_size = 0;
-  int64 total_hlos = computation.parent()->instruction_count();
+  int64 total_hlos = computation->parent()->instruction_count();
   absl::flat_hash_map<const HloInstruction*, int64> extra_users;
   absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
-  for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
+  for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
     if (ListScheduler::IgnoreInstruction(*hlo)) {
       extra_users[hlo] = 0;
       total_sizes[hlo] = 0;
@@ -448,8 +448,8 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
     total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
     extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
   }
-  CHECK_EQ(extra_users.size(), computation.instruction_count());
-  CHECK_EQ(total_sizes.size(), computation.instruction_count());
+  CHECK_EQ(extra_users.size(), computation->instruction_count());
+  CHECK_EQ(total_sizes.size(), computation->instruction_count());
 
   // Construct a total order based on DFS post-order, visiting operands in
   // decreasing cumulative extra user order, and next by cumulative size, with a
@@ -459,7 +459,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
     sequence.push_back(hlo);
     return Status::OK();
   });
-  TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
+  TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
       &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
                                              const HloInstruction* b) {
         if (extra_users[a] != extra_users[b]) {
@@ -470,12 +470,12 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
         }
         return a->name() < b->name();
       }));
-  CHECK_EQ(sequence.size(), computation.instruction_count());
+  CHECK_EQ(sequence.size(), computation->instruction_count());
   return sequence;
 } // namespace xla
 
 StatusOr<HloInstructionSequence> ListMemoryScheduler(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const absl::flat_hash_map<const HloComputation*, int64>&
@@ -485,16 +485,16 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
 }
 
 StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
-  return HloInstructionSequence(computation.MakeInstructionPostOrder());
+  return HloInstructionSequence(computation->MakeInstructionPostOrder());
 }
 
 StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const absl::flat_hash_map<const HloComputation*, int64>&
@@ -513,7 +513,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
                           memory_by_computation));
   TF_ASSIGN_OR_RETURN(const int64 list_memory,
                       HeapSimulator::MinimumMemoryForComputation(
-                          computation, list_sequence, points_to_analysis,
+                          *computation, list_sequence, points_to_analysis,
                           size_function, &memory_by_computation));
   VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
 
@@ -522,7 +522,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
                                        size_function, memory_by_computation));
   TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
                       HeapSimulator::MinimumMemoryForComputation(
-                          computation, dfs_sequence, points_to_analysis,
+                          *computation, dfs_sequence, points_to_analysis,
                           size_function, &memory_by_computation));
   VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
 
@@ -532,7 +532,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
                           memory_by_computation));
   TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
                       HeapSimulator::MinimumMemoryForComputation(
-                          computation, post_order_sequence, points_to_analysis,
+                          *computation, post_order_sequence, points_to_analysis,
                           size_function, &memory_by_computation));
   VLOG(2) << "Min-memory post order sequence: "
           << HumanReadableNumBytes(post_order_memory);
@@ -555,17 +555,17 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
 }
 
 StatusOr<HloSchedule> ScheduleModule(
-    const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
+    HloModule* module, const LogicalBuffer::SizeFunction& size_function,
     const MemorySchedulerAlgorithm& algorithm) {
-  HloSchedule schedule(&module);
+  HloSchedule schedule(module);
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
-                      TuplePointsToAnalysis::Run(&module));
+                      TuplePointsToAnalysis::Run(module));
   absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
-  for (const auto* computation : module.MakeComputationPostOrder()) {
+  for (auto* computation : module->MakeComputationPostOrder()) {
     if (!computation->IsFusionComputation()) {
       TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
                           ScheduleComputationHelper(
-                              *computation, *points_to_analysis, size_function,
+                              computation, *points_to_analysis, size_function,
                               algorithm, memory_by_computation));
       memory_by_computation[computation] =
           HeapSimulator::MinimumMemoryForComputation(
@@ -583,11 +583,11 @@ StatusOr<HloSchedule> ScheduleModule(
 }
 
 StatusOr<HloInstructionSequence> ScheduleComputation(
-    const HloComputation& computation,
+    HloComputation* computation,
     const LogicalBuffer::SizeFunction& size_function) {
-  CHECK(!computation.IsFusionComputation());
+  CHECK(!computation->IsFusionComputation());
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
-                      TuplePointsToAnalysis::Run(computation.parent()));
+                      TuplePointsToAnalysis::Run(computation->parent()));
   absl::flat_hash_map<const HloComputation*, int64> empty_map;
   return ScheduleComputationHelper(computation, *points_to_analysis,
                                    size_function, nullptr, empty_map);
@@ -600,7 +600,7 @@ HloMemoryScheduler::HloMemoryScheduler(
 
 StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
-                      ScheduleModule(*module, size_function_, algorithm_));
+                      ScheduleModule(module, size_function_, algorithm_));
   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
   return true;
 }
@@ -36,14 +36,14 @@ namespace xla {
 // that describes buffer aliasing, together with a target-specific size function
 // that maps a tensor's logical size to its padded size.
 typedef std::function<StatusOr<HloInstructionSequence>(
-    const HloComputation&, const TuplePointsToAnalysis&,
+    HloComputation*, const TuplePointsToAnalysis&,
     const LogicalBuffer::SizeFunction&,
     const absl::flat_hash_map<const HloComputation*, int64>&)>
     MemorySchedulerAlgorithm;
 
 // List scheduler
 StatusOr<HloInstructionSequence> ListMemoryScheduler(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const absl::flat_hash_map<const HloComputation*, int64>&
@@ -51,7 +51,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
 
 // DFS-order scheduler
 StatusOr<HloInstructionSequence> DFSMemoryScheduler(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const absl::flat_hash_map<const HloComputation*, int64>&
@@ -59,7 +59,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
 
 // Naive Post Order scheduler
 StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const absl::flat_hash_map<const HloComputation*, int64>&
@@ -69,7 +69,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
 // and the DFS scheduler, and chooses whichever returns a lower min-memory,
 // not accounting for fragmentation.
 StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
-    const HloComputation& computation,
+    HloComputation* computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const absl::flat_hash_map<const HloComputation*, int64>&
@@ -79,13 +79,13 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
 // the computation. size_function is the function returning the number of bytes
 // required for a LogicalBuffer.
 StatusOr<HloSchedule> ScheduleModule(
-    const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
+    HloModule* module, const LogicalBuffer::SizeFunction& size_function,
     const MemorySchedulerAlgorithm& algorithm = {});
 
 // Computes the schedule for a single computation.
 // Currently only used by the GPU backend.
 StatusOr<HloInstructionSequence> ScheduleComputation(
-    const HloComputation& computation,
+    HloComputation* computation,
     const LogicalBuffer::SizeFunction& size_function);
 
 // A pass which schedules the HLO instructions in a module. The HloModule's
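For context on the updated header above: any callable matching the MemorySchedulerAlgorithm typedef can be passed to ScheduleModule, which now takes a mutable HloComputation*. The sketch below assumes the declarations quoted in this hunk; NaiveScheduler is a hypothetical name whose body simply mirrors PostOrderMemoryScheduler:

StatusOr<HloInstructionSequence> NaiveScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // Ignore the analyses and simply replay the computation's post order.
  return HloInstructionSequence(computation->MakeInstructionPostOrder());
}

// Hypothetical call site:
//   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
//                       ScheduleModule(module, size_function, NaiveScheduler));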
@@ -78,7 +78,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
   TF_ASSERT_OK(module->schedule().Verify());
 
   // Verify that all instructions are in the sequence.
-  const std::vector<const HloInstruction*>& sequence =
+  const std::vector<HloInstruction*>& sequence =
       module->schedule().sequence(module->entry_computation()).instructions();
   EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
 
@@ -124,9 +124,9 @@ ENTRY root {
   };
   TF_ASSERT_OK_AND_ASSIGN(
       HloSchedule schedule,
-      ScheduleModule(*module, size_fn, ListMemoryScheduler));
+      ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
   // Verify that all instructions are in the sequence.
-  const std::vector<const HloInstruction*>& sequence =
+  const std::vector<HloInstruction*>& sequence =
       schedule.sequence(module->entry_computation()).instructions();
   EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
 
@@ -175,12 +175,13 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
   auto module = CreateNewVerifiedModule();
   module->AddEntryComputation(builder.Build());
   TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
-                          ScheduleModule(*module,
-                                         [](const BufferValue& buffer) {
-                                           return ShapeUtil::ByteSizeOf(
-                                               buffer.shape(), TUPLE_SIZE);
-                                         },
-                                         ListMemoryScheduler));
+                          ScheduleModule(
+                              module.get(),
+                              [](const BufferValue& buffer) {
+                                return ShapeUtil::ByteSizeOf(buffer.shape(),
+                                                             TUPLE_SIZE);
+                              },
+                              ListMemoryScheduler));
 
   // Verify that all instructions are in the sequence.
   EXPECT_EQ(module->entry_computation()->instruction_count(),
@@ -225,12 +226,12 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
       {tuple, mul, add}, HloInstruction::FusionKind::kLoop);
 
   TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
-                          ScheduleModule(*module,
-                                         [](const BufferValue& buffer) {
-                                           return ShapeUtil::ByteSizeOf(
-                                               buffer.shape(), 2);
-                                         },
-                                         ListMemoryScheduler));
+                          ScheduleModule(
+                              module.get(),
+                              [](const BufferValue& buffer) {
+                                return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
+                              },
+                              ListMemoryScheduler));
 
   // Verify that all instructions are in the sequence.
   EXPECT_EQ(module->entry_computation()->instruction_count(),
@@ -284,7 +285,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
   };
   TF_ASSERT_OK_AND_ASSIGN(
       HloSchedule schedule,
-      ScheduleModule(*module, size_fn, ListMemoryScheduler));
+      ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
   // Verify that all instructions are in the sequence.
   auto entry_computation = module->entry_computation();
   EXPECT_EQ(module->entry_computation()->instruction_count(),
@@ -104,11 +104,7 @@ class HloModule {
                                      HloCloneContext* context = nullptr);
 
   // Return a pointer to the entry computation of the module.
-  const HloComputation* entry_computation() const {
-    CHECK_NE(nullptr, entry_computation_);
-    return entry_computation_;
-  }
-  HloComputation* entry_computation() {
+  HloComputation* entry_computation() const {
     CHECK_NE(nullptr, entry_computation_);
     return entry_computation_;
   }
@@ -356,8 +356,7 @@ void SequentialHloOrdering::Initialize() {
   // Create a map from instruction to its order position.
   TF_DCHECK_OK(schedule_.Verify());
   for (const auto& computation_sequence : schedule_.sequences()) {
-    const std::vector<const HloInstruction*>& order =
-        computation_sequence.second.instructions();
+    const auto& order = computation_sequence.second.instructions();
     for (int i = 0; i < order.size(); ++i) {
       InsertOrDie(&order_position_, order[i], i);
     }
@@ -47,11 +47,11 @@ const double kF16max = 65504;
 
 // Creates and returns a schedule created using the order of the instructions in
 // the HloComputation::instructions() vectors in the module.
-HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
+HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
   HloSchedule schedule(module);
-  for (const HloComputation* computation : module->computations()) {
+  for (HloComputation* computation : module->computations()) {
     if (!computation->IsFusionComputation()) {
-      for (const HloInstruction* instruction : computation->instructions()) {
+      for (HloInstruction* instruction : computation->instructions()) {
         schedule.GetOrCreateSequence(computation).push_back(instruction);
       }
     }
@ -130,10 +130,10 @@ using ItemList = absl::InlinedVector<Item*, 3>;
|
|||||||
// before arbitrary elements.
|
// before arbitrary elements.
|
||||||
class InstructionList {
|
class InstructionList {
|
||||||
public:
|
public:
|
||||||
explicit InstructionList(const std::vector<const HloInstruction*>& order) {
|
explicit InstructionList(const HloInstructionSequence& order) {
|
||||||
int64 position = 0;
|
int64 position = 0;
|
||||||
Item* last = nullptr;
|
Item* last = nullptr;
|
||||||
for (const HloInstruction* inst : order) {
|
for (HloInstruction* inst : order.instructions()) {
|
||||||
// Add a new item to the linked list.
|
// Add a new item to the linked list.
|
||||||
Item* item = new Item;
|
Item* item = new Item;
|
||||||
item->next = nullptr;
|
item->next = nullptr;
|
||||||
@ -151,7 +151,7 @@ class InstructionList {
|
|||||||
// to be monotonically increasing through the list, and so is still useful
|
// to be monotonically increasing through the list, and so is still useful
|
||||||
// for quickly(-ish) determining the order of arbitrary instructions in
|
// for quickly(-ish) determining the order of arbitrary instructions in
|
||||||
// the list.
|
// the list.
|
||||||
item->instruction = const_cast<HloInstruction*>(inst);
|
item->instruction = inst;
|
||||||
item->position = position;
|
item->position = position;
|
||||||
position++;
|
position++;
|
||||||
|
|
||||||
@ -927,7 +927,7 @@ Item* PickRematerializationCandidate(
|
|||||||
|
|
||||||
StatusOr<int64> HloRematerialization::ComputePeakMemory(
|
StatusOr<int64> HloRematerialization::ComputePeakMemory(
|
||||||
const HloComputation* computation,
|
const HloComputation* computation,
|
||||||
const std::vector<const HloInstruction*>& order) const {
|
const HloInstructionSequence& order) const {
|
||||||
InstructionList instruction_list(order);
|
InstructionList instruction_list(order);
|
||||||
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
|
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
|
||||||
instruction_list);
|
instruction_list);
|
||||||
@@ -971,8 +971,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
           << HumanReadableNumBytes(computation_peak_memory_.at(computation));
   CHECK(!ContainsKey(rematerialized_computations_, computation));

-  InstructionList instruction_list(
-      schedule->sequence(computation).instructions());
+  InstructionList instruction_list(schedule->sequence(computation));
   MemoryUsageTracker memory_tracker(computation, size_function_,
                                     *points_to_analysis_, instruction_list);
   bool changed = false;
@@ -1184,7 +1183,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
   sequence.clear();
   for (auto* item = instruction_list.first(); item != nullptr;
        item = instruction_list.next(item)) {
-    const HloInstruction* instruction = item->instruction;
+    HloInstruction* instruction = item->instruction;
     sequence.push_back(instruction);
   }
   rematerialized_computations_.insert(computation);
@@ -1235,10 +1234,8 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
        if (node.context() == CallContext::kSequential) {
          TF_ASSIGN_OR_RETURN(
              computation_peak_memory_[node.computation()],
-              ComputePeakMemory(node.computation(),
-                                module->schedule()
-                                    .sequence(node.computation())
-                                    .instructions()));
+              ComputePeakMemory(node.computation(), module->schedule().sequence(
+                                                        node.computation())));
        }
        return Status::OK();
      },
@@ -87,9 +87,8 @@ class HloRematerialization : public HloModulePass {
   // peak memory is the maximum total size of all live HLO instruction values at
   // any program point. 'order' is the order in which the HLO instructions will
   // be emitted which is used to determine lifespans of HLO values.
-  StatusOr<int64> ComputePeakMemory(
-      const HloComputation* computation,
-      const std::vector<const HloInstruction*>& order) const;
+  StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
+                                    const HloInstructionSequence& order) const;

   // Returns the peak memory usage of the called computations for the given
   // instruction. Zero is returned if the instruction calls no computations.
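To make the 'peak memory' comment above concrete, a toy walkthrough with made-up sizes (not taken from the code):

// Suppose order = {a, b, c}, where b consumes a and c consumes b, with
// sizes |a| = 4, |b| = 8, |c| = 4 bytes.  While b is emitted, {a, b} are
// live (12 bytes); while c is emitted, {b, c} are live (12 bytes); after c
// only {c} remains (4 bytes).  Peak memory for this order is therefore 12.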
@@ -46,8 +46,8 @@ namespace xla {
        << "No computation exists in HLO module with id " << computation_id;
    const HloComputation* computation = comp_it->second;

-    absl::flat_hash_map<int64, const HloInstruction*> id_to_instruction;
-    for (const HloInstruction* instruction : computation->instructions()) {
+    absl::flat_hash_map<int64, HloInstruction*> id_to_instruction;
+    for (HloInstruction* instruction : computation->instructions()) {
      id_to_instruction[instruction->unique_id()] = instruction;
    }

@@ -81,9 +81,8 @@ StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
   return std::move(proto);
 }

-void HloSchedule::set_sequence(
-    const HloComputation* computation,
-    absl::Span<const HloInstruction* const> sequence) {
+void HloSchedule::set_sequence(const HloComputation* computation,
+                               absl::Span<HloInstruction* const> sequence) {
   set_sequence(computation, HloInstructionSequence(sequence));
 }

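A minimal caller-side sketch of the overload above, assuming `module` is an HloModule* and `param`, `add`, `root` are HloInstruction* values in its entry computation (all hypothetical names):

HloSchedule schedule(module);
// The span now carries mutable pointers, so plain HloInstruction* values work.
schedule.set_sequence(module->entry_computation(), {param, add, root});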
@@ -114,8 +113,8 @@ Status HloSchedule::UpdateComputationSchedule(
    const HloComputation* computation) {
  // Map from unique ID to HloInstruction pointer for instructions in the
  // computation.
-  absl::flat_hash_map<int, const HloInstruction*> id_to_instruction;
-  for (const HloInstruction* instruction : computation->instructions()) {
+  absl::flat_hash_map<int, HloInstruction*> id_to_instruction;
+  for (HloInstruction* instruction : computation->instructions()) {
    InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
  }

@@ -128,7 +127,7 @@ Status HloSchedule::UpdateComputationSchedule(
  // Map from HloInstruction X to newly added instructions (instruction is in
  // computation, but not in schedule) which use X. If an instruction is not in
  // the map, then it has no users which are newly added instructions.
-  absl::flat_hash_map<const HloInstruction*, std::vector<const HloInstruction*>>
+  absl::flat_hash_map<const HloInstruction*, std::vector<HloInstruction*>>
      new_instruction_uses;

  // For each newly added instruction, this is the count of the instruction's
@@ -138,9 +137,9 @@ Status HloSchedule::UpdateComputationSchedule(

  // Create a worklist of newly added instructions which are ready to be added
  // to the schedule. Initialize worklist with those that have zero operands.
-  std::queue<const HloInstruction*> worklist;
+  std::queue<HloInstruction*> worklist;

-  for (const HloInstruction* instruction : computation->instructions()) {
+  for (HloInstruction* instruction : computation->instructions()) {
    if (ids_in_schedule.count(instruction->unique_id()) == 0) {
      // This is a newly added instruction which is not in the schedule.
      if (instruction->operands().empty()) {
@@ -161,17 +160,17 @@ Status HloSchedule::UpdateComputationSchedule(
  // Lambda which schedules all instructions on the worklist.
  auto schedule_worklist = [&]() {
    while (!worklist.empty()) {
-      const HloInstruction* instruction = worklist.front();
+      HloInstruction* instruction = worklist.front();
      worklist.pop();
      new_sequence.push_back(instruction);
-      std::vector<const HloInstruction*>* new_users =
+      std::vector<HloInstruction*>* new_users =
          tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
      if (new_users != nullptr) {
        // This just-scheduled instruction has users which are newly added to
        // the module. Update the number of unscheduled operands and push the
        // newly added instruction to the worklist if it is ready to
        // schedule.
-        for (const HloInstruction* new_user : *new_users) {
+        for (HloInstruction* new_user : *new_users) {
          unscheduled_operand_count.at(new_user)--;
          CHECK_GE(unscheduled_operand_count.at(new_user), 0);
          if (unscheduled_operand_count.at(new_user) == 0) {
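The worklist handling above follows a standard ready-list (Kahn-style) scheduling pattern; a self-contained sketch of that pattern with a generic node type, purely for illustration (not XLA code):

#include <queue>
#include <vector>

struct Node {
  std::vector<Node*> users;      // nodes that consume this node's result
  int unscheduled_operands = 0;  // operands not yet emitted
};

// Emits a node once all of its operands have been emitted, mirroring the
// unscheduled_operand_count bookkeeping in UpdateComputationSchedule.
std::vector<Node*> ScheduleReady(const std::vector<Node*>& zero_operand_nodes) {
  std::vector<Node*> sequence;
  std::queue<Node*> worklist;
  for (Node* node : zero_operand_nodes) worklist.push(node);
  while (!worklist.empty()) {
    Node* node = worklist.front();
    worklist.pop();
    sequence.push_back(node);
    for (Node* user : node->users) {
      if (--user->unscheduled_operands == 0) worklist.push(user);
    }
  }
  return sequence;
}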
@@ -35,14 +35,14 @@ class HloInstructionSequence {
 public:
  HloInstructionSequence() = default;
  explicit HloInstructionSequence(
-      absl::Span<const HloInstruction* const> instructions) {
-    for (const HloInstruction* instruction : instructions) {
+      absl::Span<HloInstruction* const> instructions) {
+    for (HloInstruction* instruction : instructions) {
      push_back(instruction);
    }
  }

  // Adds the instruction to the end of the sequence.
-  void push_back(const HloInstruction* instruction) {
+  void push_back(HloInstruction* instruction) {
    instruction_sequence_.push_back(instruction);
    id_sequence_.push_back(instruction->unique_id());
  }
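A brief usage sketch of the class as it now stands, assuming `param` and `add` are HloInstruction* values (hypothetical names):

HloInstructionSequence seq;
seq.push_back(param);
seq.push_back(add);
for (HloInstruction* instruction : seq.instructions()) {
  // Mutable pointers come straight out of the sequence; no const_cast needed.
  (void)instruction;
}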
@@ -56,7 +56,7 @@ class HloInstructionSequence {
  int64 size() const { return instruction_sequence_.size(); }

  // Returns the sequence of HLO instructions.
-  const std::vector<const HloInstruction*>& instructions() const {
+  const std::vector<HloInstruction*>& instructions() const {
    return instruction_sequence_;
  }

@@ -65,7 +65,7 @@ class HloInstructionSequence {

 private:
  // The sequence as HloInstructions.
-  std::vector<const HloInstruction*> instruction_sequence_;
+  std::vector<HloInstruction*> instruction_sequence_;

  // The sequence of HLO instructions, represented by their unique IDs. The
  // sequence is stored as both HloInstructions and unique IDs because the
@@ -98,7 +98,7 @@ class HloSchedule {

  // Sets the sequence for the given computation to the given sequence.
  void set_sequence(const HloComputation* computation,
-                    absl::Span<const HloInstruction* const> sequence);
+                    absl::Span<HloInstruction* const> sequence);
  void set_sequence(const HloComputation* computation,
                    HloInstructionSequence sequence);

@@ -56,10 +56,10 @@ ENTRY main {
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
-      ScheduleModule(*module, [](const BufferValue& buffer) {
+      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape());
      }));
-  const std::vector<const HloInstruction*>& entry_schedule =
+  const auto& entry_schedule =
      schedule.sequence(module->entry_computation()).instructions();

  EXPECT_EQ(entry_schedule.size(), 6);
@@ -90,7 +90,7 @@ ENTRY main {
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
-      ScheduleModule(*module, [](const BufferValue& buffer) {
+      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape());
      }));

@@ -139,7 +139,7 @@ ENTRY main {
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
-      ScheduleModule(*module, [](const BufferValue& buffer) {
+      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape());
      }));

@@ -183,7 +183,7 @@ ENTRY main {
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
-      ScheduleModule(*module, [](const BufferValue& buffer) {
+      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape());
      }));

@@ -244,7 +244,7 @@ ENTRY %WhileLoop () -> s32[] {
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
-      ScheduleModule(*module, [](const BufferValue& buffer) {
+      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(),
                                     /*pointer_size=*/sizeof(void*));
      }));
@@ -313,7 +313,7 @@ ENTRY %WhileLoop () -> s32[] {
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
-      ScheduleModule(*module, [](const BufferValue& buffer) {
+      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(),
                                     /*pointer_size=*/sizeof(void*));
      }));