[XLA/GPU] Decouple hlo_ordering from thunk_schedule.

The plumbing before this change goes like this:
  * hlo_ordering -> buffer_assignment
  * buffer_assignment -> ir_emitter_unnested (DFS order) -> thunks
  * Apply hlo_ordering to thunks -> thunk_schedule

After:
  * hlo_ordering -> buffer_assignment
  * buffer_assignment -> ir_emitter_unnested (hlo_ordering) -> thunks
  * thunks -> thunk_schedule (order unchanged)

The idea is that since thunks are scheduled into a certain total order anyway, we can just use that order to invoke the emitter. This saves an extra scheduling pass, but most importantly it removes uses of Thunk::hlo_instruction(), which makes the MLIR/GPU transition easier.
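
To make the before/after concrete, here is a minimal sketch with toy stand-ins for XLA's types (the Thunk/ThunkSequence below are simplified assumptions, not the real classes). Before, the total order had to be recovered by mapping HLO instructions back to thunks; after, the emitted sequence already is the launch order:

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in: the real Thunk exposes hlo_instruction(), modeled here as a
// plain name.
struct Thunk {
  std::string hlo_name;
};
using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;

// Old path: re-sort the thunks into the schedule's total order, keyed by the
// HLO instruction each thunk was emitted from.
std::vector<Thunk*> ScheduleByHlo(const ThunkSequence& thunks,
                                  const std::vector<std::string>& total_order) {
  std::unordered_map<std::string, Thunk*> hlo_to_thunk;
  for (const auto& thunk : thunks) hlo_to_thunk[thunk->hlo_name] = thunk.get();
  std::vector<Thunk*> ordered;
  for (const auto& hlo : total_order) {
    if (auto it = hlo_to_thunk.find(hlo); it != hlo_to_thunk.end()) {
      ordered.push_back(it->second);
    }
  }
  return ordered;
}

// New path: the emitter was driven in schedule order, so the sequence itself
// is the total order and no HLO lookup is needed.
std::vector<Thunk*> ScheduleAsEmitted(const ThunkSequence& thunks) {
  std::vector<Thunk*> ordered;
  ordered.reserve(thunks.size());
  for (const auto& thunk : thunks) ordered.push_back(thunk.get());
  return ordered;
}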

PiperOrigin-RevId: 320117281
Change-Id: I0ee9ff14e71869ea09d6223ae10448317298096f
Author: Tim Shen
Committed by: TensorFlower Gardener
Date: 2020-07-07 20:55:32 -07:00
Commit: 82e12bf387
Parent: a7ee6a72ff
6 changed files with 20 additions and 19 deletions


@@ -518,8 +518,11 @@ static Status CompileModuleToLlvmIrImpl(
   {
     XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
-    TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
+    TF_RETURN_IF_ERROR(entry_computation->AcceptOrdered(
+        &ir_emitter, (*hlo_schedule)->ThunkLaunchOrder()));
   }
 
+  // The order of `thunk_sequence` corresponds to
+  // `hlo_schedule->ThunkLaunchOrder()`.
   *thunk_sequence = ir_emitter.ConsumeThunkSequence();
   return Status::OK();
 }
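
For context on the Accept -> AcceptOrdered switch above: Accept walks the computation in DFS post-order, while AcceptOrdered visits instructions in a caller-supplied total order (which must still be topologically valid). A toy analogue, with Node and the visit callback as hypothetical stand-ins for HloInstruction and DfsHloVisitor:

#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> operands;
};

// Analogue of HloComputation::Accept: DFS post-order, visiting each node once.
void Accept(Node* root, const std::function<void(Node*)>& visit,
            std::unordered_set<Node*>& visited) {
  if (!visited.insert(root).second) return;
  for (Node* operand : root->operands) Accept(operand, visit, visited);
  visit(root);
}

// Analogue of HloComputation::AcceptOrdered: the caller guarantees every
// operand appears before its users in `order`.
void AcceptOrdered(const std::vector<Node*>& order,
                   const std::function<void(Node*)>& visit) {
  for (Node* node : order) visit(node);
}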
@@ -610,8 +613,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
                                     gpu_version, stream_exec));
 
   auto thunk_schedule = absl::make_unique<ThunkSchedule>(
-      std::move(thunk_sequence), std::move(stream_assignment),
-      hlo_schedule->ThunkLaunchOrder());
+      std::move(thunk_sequence), std::move(stream_assignment));
 
   if (DumpingEnabledForHloModule(*module)) {
     DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
                             thunk_schedule->ToString());


@@ -49,21 +49,18 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands(
 
 ThunkSchedule::ThunkSchedule(
     std::unique_ptr<ThunkSequence> thunks,
-    std::unique_ptr<StreamAssignment> stream_assignment,
-    const std::vector<HloInstruction*>& hlo_total_order)
+    std::unique_ptr<StreamAssignment> stream_assignment)
     : thunks_(std::move(thunks)),
       stream_assignment_(std::move(stream_assignment)) {
+  for (auto& thunk : *thunks_) {
+    thunk_total_order_.push_back(thunk.get());
+  }
-  absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
-  for (const auto& thunk : *thunks_) {
-    InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
-  }
-  for (HloInstruction* hlo : hlo_total_order) {
-    if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) {
-      thunk_total_order_.push_back(*thunk);
-    }
-  }
+
   for (const Thunk* thunk : thunk_total_order_) {
     const auto* dst = thunk->hlo_instruction();
     CHECK(stream_assignment_->HasStreamAssigned(*dst));
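
Tying this back to the first sketch: a hypothetical check (reusing the toy ScheduleByHlo/ScheduleAsEmitted above, not real XLA code) that the old and new constructor logic agree whenever emission already followed the schedule, which is exactly what the AcceptOrdered call now guarantees:

#include <cassert>

int main() {
  ThunkSequence thunks;
  for (const char* name : {"copy", "fusion", "gemm"}) {
    thunks.push_back(std::make_unique<Thunk>(Thunk{name}));
  }
  // The schedule matches the emission order, so post-hoc reordering (old)
  // and taking the sequence as-is (new) produce the same launch order.
  std::vector<std::string> schedule = {"copy", "fusion", "gemm"};
  assert(ScheduleByHlo(thunks, schedule) == ScheduleAsEmitted(thunks));
  return 0;
}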


@@ -47,8 +47,7 @@ namespace gpu {
 class ThunkSchedule {
  public:
   ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,
-                std::unique_ptr<StreamAssignment> stream_assignment,
-                const std::vector<HloInstruction*>& hlo_total_order);
+                std::unique_ptr<StreamAssignment> stream_assignment);
 
   // Returns the total order of executing all the thunks.
   const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }


@@ -226,8 +226,10 @@ absl::string_view LhloDialectEmitter::platform_name() const {
   return platform_->Name();
 }
 
-Status LhloDialectEmitter::EmitComputation(const HloComputation& computation) {
-  return computation.root_instruction()->Accept(this);
+Status LhloDialectEmitter::EmitComputation(
+    const HloComputation& computation,
+    absl::Span<HloInstruction* const> ordering) {
+  return computation.AcceptOrdered(this, ordering);
 }
 
 StatusOr<FuncOp> LhloDialectEmitter::CreateFunction(
StatusOr<FuncOp> LhloDialectEmitter::CreateFunction(


@@ -47,7 +47,8 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
                     ::mlir::ModuleOp mlir_module);
   ~LhloDialectEmitter() override = default;
 
-  Status EmitComputation(const HloComputation& computation);
+  Status EmitComputation(const HloComputation& computation,
+                         absl::Span<HloInstruction* const> ordering);
 
   // The following methods implement the DfsHloVisitor interface.
   //


@@ -489,7 +489,8 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
                                 stream_exec->platform(), *mlir_module);
 
   TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation(
-      *emission_context.getHloModule()->entry_computation()));
+      *emission_context.getHloModule()->entry_computation(),
+      hlo_schedule->ThunkLaunchOrder()));
 
   TF_RETURN_IF_ERROR(
       module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module));
@@ -539,8 +540,7 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
                                     gpu::PtxOptsFromConfig(config)));
 
   auto thunk_schedule = absl::make_unique<ThunkSchedule>(
-      std::move(thunk_sequence), std::move(stream_assignment),
-      hlo_schedule->ThunkLaunchOrder());
+      std::move(thunk_sequence), std::move(stream_assignment));
 
   if (DumpingEnabledForHloModule(*emission_context.getHloModule())) {
     DumpToFileInDirOrStdout(*emission_context.getHloModule(), "",