From 96fb44b7b4c2fb4c1ec15ef1c42b649ab3d0999e Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Tue, 6 Oct 2020 12:03:50 -0700 Subject: [PATCH] [NFC] Eliminate references to HLO Inst from ForThunk - HLO Inst is used just for body profile index, so stash that directly in the Thunk. PiperOrigin-RevId: 335690149 Change-Id: Iee12dfa1b52daf373d42d2fb45f49535bc687f0a --- tensorflow/compiler/xla/service/gpu/for_thunk.cc | 12 ++++++------ tensorflow/compiler/xla/service/gpu/for_thunk.h | 5 +++-- .../compiler/xla/service/gpu/ir_emitter_unnested.cc | 9 ++++++++- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index ccd661d8ade..a9e6cd05c31 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -24,15 +24,16 @@ namespace xla { namespace gpu { ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit, - std::unique_ptr body_thunk_sequence) + std::unique_ptr body_thunk_sequence, + absl::optional body_profile_index) : Thunk(Kind::kWhile, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), loop_limit_(loop_limit), body_thunk_sequence_(absl::make_unique( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. - ThunkInfo(), std::move(*body_thunk_sequence))) {} + ThunkInfo(), std::move(*body_thunk_sequence))), + body_profile_index_(body_profile_index) {} Status ForThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -41,15 +42,14 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const ExecuteParams& params) { - VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " - << (hlo_instruction_ ? hlo_instruction_->ToString() : ""); + VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters"; auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); for (int64 i = 0; i < loop_limit_; ++i) { params.profiler->StartHloComputation(); // Invoke loop body thunk sequence. TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - params.profiler->FinishHloComputation(hlo_instruction_->while_body()); + params.profiler->FinishHloComputation(body_profile_index_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index b6ee950737e..96f0534cd52 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -32,7 +32,8 @@ namespace gpu { class ForThunk : public Thunk { public: ForThunk(ThunkInfo thunk_info, const int64 loop_limit, - std::unique_ptr body_thunk_sequence); + std::unique_ptr body_thunk_sequence, + absl::optional body_profile_index_); ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; @@ -41,9 +42,9 @@ class ForThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; const int64 loop_limit_; std::unique_ptr body_thunk_sequence_; + absl::optional body_profile_index_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 7459cb68955..4da05e140c5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2272,8 +2272,15 @@ StatusOr> IrEmitterUnnested::BuildForThunk( IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); + const auto* index_map = ir_emitter_context_->profile_index_map(); + absl::optional body_profile_index; + if (index_map) { + body_profile_index = index_map->GetProfileIndexFor(*body); + } + return std::unique_ptr(new ForThunk( - GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence())); + GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence(), + body_profile_index)); } StatusOr> IrEmitterUnnested::BuildConditionalThunk(