[NFC] Eliminate references to HLO Inst from WhileThunk

- The HLO instruction was used only to look up the profile indices of the while
  condition and body computations. Store those two indices in the thunk instead
  and remove the HLO instruction pointer from the thunk.

PiperOrigin-RevId: 335713572
Change-Id: I7be857bbe967fa256f4357ffdc807345e540f318
This commit is contained in:
Rahul Joshi 2020-10-06 13:51:57 -07:00 committed by TensorFlower Gardener
parent 2f9ba0aae7
commit 993faabe64
3 changed files with 22 additions and 7 deletions

View File

@ -2252,11 +2252,19 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
const auto* index_map = ir_emitter_context_->profile_index_map();
absl::optional<size_t> condition_profile_index, body_profile_index;
if (index_map) {
condition_profile_index = index_map->GetProfileIndexFor(*condition);
body_profile_index = index_map->GetProfileIndexFor(*body);
}
return std::unique_ptr<Thunk>(new WhileThunk(
GetThunkInfo(hlo),
GetAllocationSlice(*condition->root_instruction()), // cond result
ir_emitter_condition->ConsumeThunkSequence(),
ir_emitter_body->ConsumeThunkSequence()));
ir_emitter_body->ConsumeThunkSequence(), condition_profile_index,
body_profile_index));
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(

View File

@ -27,9 +27,10 @@ WhileThunk::WhileThunk(
ThunkInfo thunk_info,
const BufferAllocation::Slice& condition_result_buffer_index,
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
std::unique_ptr<ThunkSequence> body_thunk_sequence)
std::unique_ptr<ThunkSequence> body_thunk_sequence,
absl::optional<size_t> condition_profile_index,
absl::optional<size_t> body_profile_index)
: Thunk(Kind::kWhile, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
condition_result_buffer_index_(condition_result_buffer_index),
// Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
// and body_thunk_sequence_ constructors because these SequentialThunks
@ -38,7 +39,9 @@ WhileThunk::WhileThunk(
condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
ThunkInfo(), std::move(*condition_thunk_sequence))),
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
ThunkInfo(), std::move(*body_thunk_sequence))) {}
ThunkInfo(), std::move(*body_thunk_sequence))),
condition_profile_index_(condition_profile_index),
body_profile_index_(body_profile_index) {}
Status WhileThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
@ -62,7 +65,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) {
profiler.StartHloComputation();
VLOG(3) << "Executing condition computation";
TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params));
profiler.FinishHloComputation(hlo_instruction_->while_condition());
profiler.FinishHloComputation(condition_profile_index_);
// Copy the result of condition computation and break the loop if 'false'.
bool condition_result;
@ -86,7 +89,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) {
// Invoke thunk sequence for while 'body' computation, and pass on
// 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params));
profiler.FinishHloComputation(hlo_instruction_->while_body());
profiler.FinishHloComputation(body_profile_index_);
}
return Status::OK();
}

View File

@ -42,7 +42,9 @@ class WhileThunk : public Thunk {
WhileThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& condition_result_buffer_index,
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
std::unique_ptr<ThunkSequence> body_thunk_sequence);
std::unique_ptr<ThunkSequence> body_thunk_sequence,
absl::optional<size_t> condition_profile_index,
absl::optional<size_t> body_profile_index);
WhileThunk(const WhileThunk&) = delete;
WhileThunk& operator=(const WhileThunk&) = delete;
@ -55,6 +57,8 @@ class WhileThunk : public Thunk {
const BufferAllocation::Slice condition_result_buffer_index_;
std::unique_ptr<SequentialThunk> condition_thunk_sequence_;
std::unique_ptr<SequentialThunk> body_thunk_sequence_;
absl::optional<size_t> condition_profile_index_;
absl::optional<size_t> body_profile_index_;
};
} // namespace gpu