[NFC] Eliminate references to HLO Inst from WhileThunk
- HLO Inst is used to get the profile indices for the condition and body. Stash those inside the thunk instead and remove the HLO Inst pointer from the thunk.
PiperOrigin-RevId: 335713572
Change-Id: I7be857bbe967fa256f4357ffdc807345e540f318
This commit is contained in:
parent
2f9ba0aae7
commit
993faabe64
@ -2252,11 +2252,19 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
|
||||
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
|
||||
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
|
||||
|
||||
const auto* index_map = ir_emitter_context_->profile_index_map();
|
||||
absl::optional<size_t> condition_profile_index, body_profile_index;
|
||||
if (index_map) {
|
||||
condition_profile_index = index_map->GetProfileIndexFor(*condition);
|
||||
body_profile_index = index_map->GetProfileIndexFor(*body);
|
||||
}
|
||||
|
||||
return std::unique_ptr<Thunk>(new WhileThunk(
|
||||
GetThunkInfo(hlo),
|
||||
GetAllocationSlice(*condition->root_instruction()), // cond result
|
||||
ir_emitter_condition->ConsumeThunkSequence(),
|
||||
ir_emitter_body->ConsumeThunkSequence()));
|
||||
ir_emitter_body->ConsumeThunkSequence(), condition_profile_index,
|
||||
body_profile_index));
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
|
||||
|
@ -27,9 +27,10 @@ WhileThunk::WhileThunk(
|
||||
ThunkInfo thunk_info,
|
||||
const BufferAllocation::Slice& condition_result_buffer_index,
|
||||
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
|
||||
std::unique_ptr<ThunkSequence> body_thunk_sequence)
|
||||
std::unique_ptr<ThunkSequence> body_thunk_sequence,
|
||||
absl::optional<size_t> condition_profile_index,
|
||||
absl::optional<size_t> body_profile_index)
|
||||
: Thunk(Kind::kWhile, thunk_info),
|
||||
hlo_instruction_(thunk_info.hlo_instruction),
|
||||
condition_result_buffer_index_(condition_result_buffer_index),
|
||||
// Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
|
||||
// and body_thunk_sequence_ constructors because these SequentialThunks
|
||||
@ -38,7 +39,9 @@ WhileThunk::WhileThunk(
|
||||
condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
|
||||
ThunkInfo(), std::move(*condition_thunk_sequence))),
|
||||
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
|
||||
ThunkInfo(), std::move(*body_thunk_sequence))) {}
|
||||
ThunkInfo(), std::move(*body_thunk_sequence))),
|
||||
condition_profile_index_(condition_profile_index),
|
||||
body_profile_index_(body_profile_index) {}
|
||||
|
||||
Status WhileThunk::Initialize(const GpuExecutable& executable,
|
||||
se::StreamExecutor* executor) {
|
||||
@ -62,7 +65,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
profiler.StartHloComputation();
|
||||
VLOG(3) << "Executing condition computation";
|
||||
TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params));
|
||||
profiler.FinishHloComputation(hlo_instruction_->while_condition());
|
||||
profiler.FinishHloComputation(condition_profile_index_);
|
||||
|
||||
// Copy the result of condition computation and break the loop if 'false'.
|
||||
bool condition_result;
|
||||
@ -86,7 +89,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
// Invoke thunk sequence for while 'body' computation, and pass on
|
||||
// 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
|
||||
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params));
|
||||
profiler.FinishHloComputation(hlo_instruction_->while_body());
|
||||
profiler.FinishHloComputation(body_profile_index_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -42,7 +42,9 @@ class WhileThunk : public Thunk {
|
||||
WhileThunk(ThunkInfo thunk_info,
|
||||
const BufferAllocation::Slice& condition_result_buffer_index,
|
||||
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
|
||||
std::unique_ptr<ThunkSequence> body_thunk_sequence);
|
||||
std::unique_ptr<ThunkSequence> body_thunk_sequence,
|
||||
absl::optional<size_t> condition_profile_index,
|
||||
absl::optional<size_t> body_profile_index);
|
||||
WhileThunk(const WhileThunk&) = delete;
|
||||
WhileThunk& operator=(const WhileThunk&) = delete;
|
||||
|
||||
@ -55,6 +57,8 @@ class WhileThunk : public Thunk {
|
||||
const BufferAllocation::Slice condition_result_buffer_index_;
|
||||
std::unique_ptr<SequentialThunk> condition_thunk_sequence_;
|
||||
std::unique_ptr<SequentialThunk> body_thunk_sequence_;
|
||||
absl::optional<size_t> condition_profile_index_;
|
||||
absl::optional<size_t> body_profile_index_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
Loading…
x
Reference in New Issue
Block a user