diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 4da05e140c5..8f01d7e3c41 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2252,11 +2252,19 @@ StatusOr> IrEmitterUnnested::BuildWhileThunk( IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); + const auto* index_map = ir_emitter_context_->profile_index_map(); + absl::optional condition_profile_index, body_profile_index; + if (index_map) { + condition_profile_index = index_map->GetProfileIndexFor(*condition); + body_profile_index = index_map->GetProfileIndexFor(*body); + } + return std::unique_ptr(new WhileThunk( GetThunkInfo(hlo), GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition->ConsumeThunkSequence(), - ir_emitter_body->ConsumeThunkSequence())); + ir_emitter_body->ConsumeThunkSequence(), condition_profile_index, + body_profile_index)); } StatusOr> IrEmitterUnnested::BuildForThunk( diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 792479df4ac..6397ad3bee0 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -27,9 +27,10 @@ WhileThunk::WhileThunk( ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence) + std::unique_ptr body_thunk_sequence, + absl::optional condition_profile_index, + absl::optional body_profile_index) : Thunk(Kind::kWhile, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), condition_result_buffer_index_(condition_result_buffer_index), // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_ // and body_thunk_sequence_ constructors because these SequentialThunks @@ -38,7 +39,9 @@ WhileThunk::WhileThunk( condition_thunk_sequence_(absl::make_unique( ThunkInfo(), std::move(*condition_thunk_sequence))), body_thunk_sequence_(absl::make_unique( - ThunkInfo(), std::move(*body_thunk_sequence))) {} + ThunkInfo(), std::move(*body_thunk_sequence))), + condition_profile_index_(condition_profile_index), + body_profile_index_(body_profile_index) {} Status WhileThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -62,7 +65,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { profiler.StartHloComputation(); VLOG(3) << "Executing condition computation"; TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params)); - profiler.FinishHloComputation(hlo_instruction_->while_condition()); + profiler.FinishHloComputation(condition_profile_index_); // Copy the result of condition computation and break the loop if 'false'. bool condition_result; @@ -86,7 +89,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { // Invoke thunk sequence for while 'body' computation, and pass on // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - profiler.FinishHloComputation(hlo_instruction_->while_body()); + profiler.FinishHloComputation(body_profile_index_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 707bac15bb2..dc09c142a88 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -42,7 +42,9 @@ class WhileThunk : public Thunk { WhileThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence); + std::unique_ptr body_thunk_sequence, + absl::optional condition_profile_index, + absl::optional body_profile_index); WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; @@ -55,6 +57,8 @@ class WhileThunk : public Thunk { const BufferAllocation::Slice condition_result_buffer_index_; std::unique_ptr condition_thunk_sequence_; std::unique_ptr body_thunk_sequence_; + absl::optional condition_profile_index_; + absl::optional body_profile_index_; }; } // namespace gpu