[NFC] Eliminate references to HLO Inst from WhileThunk
- HLO Inst is used to get profile index for condition and body. Stash those inside the thunk instead and remove the HLO Inst pointer from the thunk. PiperOrigin-RevId: 335713572 Change-Id: I7be857bbe967fa256f4357ffdc807345e540f318
This commit is contained in:
parent
2f9ba0aae7
commit
993faabe64
@ -2252,11 +2252,19 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
|
|||||||
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
|
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
|
||||||
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
|
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
|
||||||
|
|
||||||
|
const auto* index_map = ir_emitter_context_->profile_index_map();
|
||||||
|
absl::optional<size_t> condition_profile_index, body_profile_index;
|
||||||
|
if (index_map) {
|
||||||
|
condition_profile_index = index_map->GetProfileIndexFor(*condition);
|
||||||
|
body_profile_index = index_map->GetProfileIndexFor(*body);
|
||||||
|
}
|
||||||
|
|
||||||
return std::unique_ptr<Thunk>(new WhileThunk(
|
return std::unique_ptr<Thunk>(new WhileThunk(
|
||||||
GetThunkInfo(hlo),
|
GetThunkInfo(hlo),
|
||||||
GetAllocationSlice(*condition->root_instruction()), // cond result
|
GetAllocationSlice(*condition->root_instruction()), // cond result
|
||||||
ir_emitter_condition->ConsumeThunkSequence(),
|
ir_emitter_condition->ConsumeThunkSequence(),
|
||||||
ir_emitter_body->ConsumeThunkSequence()));
|
ir_emitter_body->ConsumeThunkSequence(), condition_profile_index,
|
||||||
|
body_profile_index));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
|
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
|
||||||
|
|||||||
@ -27,9 +27,10 @@ WhileThunk::WhileThunk(
|
|||||||
ThunkInfo thunk_info,
|
ThunkInfo thunk_info,
|
||||||
const BufferAllocation::Slice& condition_result_buffer_index,
|
const BufferAllocation::Slice& condition_result_buffer_index,
|
||||||
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
|
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
|
||||||
std::unique_ptr<ThunkSequence> body_thunk_sequence)
|
std::unique_ptr<ThunkSequence> body_thunk_sequence,
|
||||||
|
absl::optional<size_t> condition_profile_index,
|
||||||
|
absl::optional<size_t> body_profile_index)
|
||||||
: Thunk(Kind::kWhile, thunk_info),
|
: Thunk(Kind::kWhile, thunk_info),
|
||||||
hlo_instruction_(thunk_info.hlo_instruction),
|
|
||||||
condition_result_buffer_index_(condition_result_buffer_index),
|
condition_result_buffer_index_(condition_result_buffer_index),
|
||||||
// Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
|
// Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
|
||||||
// and body_thunk_sequence_ constructors because these SequentialThunks
|
// and body_thunk_sequence_ constructors because these SequentialThunks
|
||||||
@ -38,7 +39,9 @@ WhileThunk::WhileThunk(
|
|||||||
condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
|
condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
|
||||||
ThunkInfo(), std::move(*condition_thunk_sequence))),
|
ThunkInfo(), std::move(*condition_thunk_sequence))),
|
||||||
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
|
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
|
||||||
ThunkInfo(), std::move(*body_thunk_sequence))) {}
|
ThunkInfo(), std::move(*body_thunk_sequence))),
|
||||||
|
condition_profile_index_(condition_profile_index),
|
||||||
|
body_profile_index_(body_profile_index) {}
|
||||||
|
|
||||||
Status WhileThunk::Initialize(const GpuExecutable& executable,
|
Status WhileThunk::Initialize(const GpuExecutable& executable,
|
||||||
se::StreamExecutor* executor) {
|
se::StreamExecutor* executor) {
|
||||||
@ -62,7 +65,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
profiler.StartHloComputation();
|
profiler.StartHloComputation();
|
||||||
VLOG(3) << "Executing condition computation";
|
VLOG(3) << "Executing condition computation";
|
||||||
TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params));
|
TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params));
|
||||||
profiler.FinishHloComputation(hlo_instruction_->while_condition());
|
profiler.FinishHloComputation(condition_profile_index_);
|
||||||
|
|
||||||
// Copy the result of condition computation and break the loop if 'false'.
|
// Copy the result of condition computation and break the loop if 'false'.
|
||||||
bool condition_result;
|
bool condition_result;
|
||||||
@ -86,7 +89,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
// Invoke thunk sequence for while 'body' computation, and pass on
|
// Invoke thunk sequence for while 'body' computation, and pass on
|
||||||
// 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
|
// 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
|
||||||
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params));
|
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params));
|
||||||
profiler.FinishHloComputation(hlo_instruction_->while_body());
|
profiler.FinishHloComputation(body_profile_index_);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,7 +42,9 @@ class WhileThunk : public Thunk {
|
|||||||
WhileThunk(ThunkInfo thunk_info,
|
WhileThunk(ThunkInfo thunk_info,
|
||||||
const BufferAllocation::Slice& condition_result_buffer_index,
|
const BufferAllocation::Slice& condition_result_buffer_index,
|
||||||
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
|
std::unique_ptr<ThunkSequence> condition_thunk_sequence,
|
||||||
std::unique_ptr<ThunkSequence> body_thunk_sequence);
|
std::unique_ptr<ThunkSequence> body_thunk_sequence,
|
||||||
|
absl::optional<size_t> condition_profile_index,
|
||||||
|
absl::optional<size_t> body_profile_index);
|
||||||
WhileThunk(const WhileThunk&) = delete;
|
WhileThunk(const WhileThunk&) = delete;
|
||||||
WhileThunk& operator=(const WhileThunk&) = delete;
|
WhileThunk& operator=(const WhileThunk&) = delete;
|
||||||
|
|
||||||
@ -55,6 +57,8 @@ class WhileThunk : public Thunk {
|
|||||||
const BufferAllocation::Slice condition_result_buffer_index_;
|
const BufferAllocation::Slice condition_result_buffer_index_;
|
||||||
std::unique_ptr<SequentialThunk> condition_thunk_sequence_;
|
std::unique_ptr<SequentialThunk> condition_thunk_sequence_;
|
||||||
std::unique_ptr<SequentialThunk> body_thunk_sequence_;
|
std::unique_ptr<SequentialThunk> body_thunk_sequence_;
|
||||||
|
absl::optional<size_t> condition_profile_index_;
|
||||||
|
absl::optional<size_t> body_profile_index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user