[XLA/GPU] Decouple hlo_ordering from thunk_schedule.
The plumbing before this change goes like this:
* hlo_ordering -> buffer_assignment
* buffer_assignment -> ir_emitter_unnested (DFS order) -> thunks
* Apply hlo_ordering to thunks -> thunk_schedule

After:
* hlo_ordering -> buffer_assignment
* buffer_assignment -> ir_emitter_unnested (hlo_ordering) -> thunks
* thunks -> thunk_schedule (order unchanged)

The idea is that since thunks are scheduled to a certain total order anyway, we can just use that order to invoke the emitter. It saves an extra schedule, but most importantly, it removes uses of Thunk::hlo_instruction(), which makes the MLIR/GPU transition easier.

PiperOrigin-RevId: 320117281
Change-Id: I0ee9ff14e71869ea09d6223ae10448317298096f
This commit is contained in:
parent
a7ee6a72ff
commit
82e12bf387
@ -518,8 +518,11 @@ static Status CompileModuleToLlvmIrImpl(
|
|||||||
|
|
||||||
{
|
{
|
||||||
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
|
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
|
||||||
TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
|
TF_RETURN_IF_ERROR(entry_computation->AcceptOrdered(
|
||||||
|
&ir_emitter, (*hlo_schedule)->ThunkLaunchOrder()));
|
||||||
}
|
}
|
||||||
|
// The order of `thunk_sequence` corresponds to
|
||||||
|
// `hlo_schedule->ThunkLaunchOrder()`.
|
||||||
*thunk_sequence = ir_emitter.ConsumeThunkSequence();
|
*thunk_sequence = ir_emitter.ConsumeThunkSequence();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -610,8 +613,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
|||||||
gpu_version, stream_exec));
|
gpu_version, stream_exec));
|
||||||
|
|
||||||
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
|
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
|
||||||
std::move(thunk_sequence), std::move(stream_assignment),
|
std::move(thunk_sequence), std::move(stream_assignment));
|
||||||
hlo_schedule->ThunkLaunchOrder());
|
|
||||||
if (DumpingEnabledForHloModule(*module)) {
|
if (DumpingEnabledForHloModule(*module)) {
|
||||||
DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
|
DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
|
||||||
thunk_schedule->ToString());
|
thunk_schedule->ToString());
|
||||||
|
@ -49,21 +49,18 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands(
|
|||||||
|
|
||||||
ThunkSchedule::ThunkSchedule(
|
ThunkSchedule::ThunkSchedule(
|
||||||
std::unique_ptr<ThunkSequence> thunks,
|
std::unique_ptr<ThunkSequence> thunks,
|
||||||
std::unique_ptr<StreamAssignment> stream_assignment,
|
std::unique_ptr<StreamAssignment> stream_assignment)
|
||||||
const std::vector<HloInstruction*>& hlo_total_order)
|
|
||||||
: thunks_(std::move(thunks)),
|
: thunks_(std::move(thunks)),
|
||||||
stream_assignment_(std::move(stream_assignment)) {
|
stream_assignment_(std::move(stream_assignment)) {
|
||||||
|
for (auto& thunk : *thunks_) {
|
||||||
|
thunk_total_order_.push_back(thunk.get());
|
||||||
|
}
|
||||||
|
|
||||||
absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
|
absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
|
||||||
for (const auto& thunk : *thunks_) {
|
for (const auto& thunk : *thunks_) {
|
||||||
InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
|
InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (HloInstruction* hlo : hlo_total_order) {
|
|
||||||
if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) {
|
|
||||||
thunk_total_order_.push_back(*thunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const Thunk* thunk : thunk_total_order_) {
|
for (const Thunk* thunk : thunk_total_order_) {
|
||||||
const auto* dst = thunk->hlo_instruction();
|
const auto* dst = thunk->hlo_instruction();
|
||||||
CHECK(stream_assignment_->HasStreamAssigned(*dst));
|
CHECK(stream_assignment_->HasStreamAssigned(*dst));
|
||||||
|
@ -47,8 +47,7 @@ namespace gpu {
|
|||||||
class ThunkSchedule {
|
class ThunkSchedule {
|
||||||
public:
|
public:
|
||||||
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,
|
ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,
|
||||||
std::unique_ptr<StreamAssignment> stream_assignment,
|
std::unique_ptr<StreamAssignment> stream_assignment);
|
||||||
const std::vector<HloInstruction*>& hlo_total_order);
|
|
||||||
|
|
||||||
// Returns the total order of executing all the thunks.
|
// Returns the total order of executing all the thunks.
|
||||||
const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }
|
const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }
|
||||||
|
@ -226,8 +226,10 @@ absl::string_view LhloDialectEmitter::platform_name() const {
|
|||||||
return platform_->Name();
|
return platform_->Name();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status LhloDialectEmitter::EmitComputation(const HloComputation& computation) {
|
Status LhloDialectEmitter::EmitComputation(
|
||||||
return computation.root_instruction()->Accept(this);
|
const HloComputation& computation,
|
||||||
|
absl::Span<HloInstruction* const> ordering) {
|
||||||
|
return computation.AcceptOrdered(this, ordering);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<FuncOp> LhloDialectEmitter::CreateFunction(
|
StatusOr<FuncOp> LhloDialectEmitter::CreateFunction(
|
||||||
|
@ -47,7 +47,8 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
|
|||||||
::mlir::ModuleOp mlir_module);
|
::mlir::ModuleOp mlir_module);
|
||||||
~LhloDialectEmitter() override = default;
|
~LhloDialectEmitter() override = default;
|
||||||
|
|
||||||
Status EmitComputation(const HloComputation& computation);
|
Status EmitComputation(const HloComputation& computation,
|
||||||
|
absl::Span<HloInstruction* const> ordering);
|
||||||
|
|
||||||
// The following methods implement the DfsHloVisitor interface.
|
// The following methods implement the DfsHloVisitor interface.
|
||||||
//
|
//
|
||||||
|
@ -489,7 +489,8 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
|
|||||||
stream_exec->platform(), *mlir_module);
|
stream_exec->platform(), *mlir_module);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation(
|
TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation(
|
||||||
*emission_context.getHloModule()->entry_computation()));
|
*emission_context.getHloModule()->entry_computation(),
|
||||||
|
hlo_schedule->ThunkLaunchOrder()));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module));
|
module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module));
|
||||||
@ -539,8 +540,7 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
|
|||||||
gpu::PtxOptsFromConfig(config)));
|
gpu::PtxOptsFromConfig(config)));
|
||||||
|
|
||||||
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
|
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
|
||||||
std::move(thunk_sequence), std::move(stream_assignment),
|
std::move(thunk_sequence), std::move(stream_assignment));
|
||||||
hlo_schedule->ThunkLaunchOrder());
|
|
||||||
|
|
||||||
if (DumpingEnabledForHloModule(*emission_context.getHloModule())) {
|
if (DumpingEnabledForHloModule(*emission_context.getHloModule())) {
|
||||||
DumpToFileInDirOrStdout(*emission_context.getHloModule(), "",
|
DumpToFileInDirOrStdout(*emission_context.getHloModule(), "",
|
||||||
|
Loading…
Reference in New Issue
Block a user