[NFC] Eliminate references to HLO Inst from ForThunk
- HLO Inst is used just for body profile index, so stash that directly in the Thunk. PiperOrigin-RevId: 335690149 Change-Id: Iee12dfa1b52daf373d42d2fb45f49535bc687f0a
This commit is contained in:
parent
04ea37b5ed
commit
96fb44b7b4
@ -24,15 +24,16 @@ namespace xla {
|
|||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
|
ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
|
||||||
std::unique_ptr<ThunkSequence> body_thunk_sequence)
|
std::unique_ptr<ThunkSequence> body_thunk_sequence,
|
||||||
|
absl::optional<size_t> body_profile_index)
|
||||||
: Thunk(Kind::kWhile, thunk_info),
|
: Thunk(Kind::kWhile, thunk_info),
|
||||||
hlo_instruction_(thunk_info.hlo_instruction),
|
|
||||||
loop_limit_(loop_limit),
|
loop_limit_(loop_limit),
|
||||||
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
|
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
|
||||||
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_
|
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_
|
||||||
// constructor because this SequentialThunk is logically "part of"
|
// constructor because this SequentialThunk is logically "part of"
|
||||||
// this ForThunk, and shouldn't be profiled separately from it.
|
// this ForThunk, and shouldn't be profiled separately from it.
|
||||||
ThunkInfo(), std::move(*body_thunk_sequence))) {}
|
ThunkInfo(), std::move(*body_thunk_sequence))),
|
||||||
|
body_profile_index_(body_profile_index) {}
|
||||||
|
|
||||||
Status ForThunk::Initialize(const GpuExecutable& executable,
|
Status ForThunk::Initialize(const GpuExecutable& executable,
|
||||||
se::StreamExecutor* executor) {
|
se::StreamExecutor* executor) {
|
||||||
@ -41,15 +42,14 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status ForThunk::ExecuteOnStream(const ExecuteParams& params) {
|
Status ForThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||||
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
|
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters";
|
||||||
<< (hlo_instruction_ ? hlo_instruction_->ToString() : "<null>");
|
|
||||||
auto op_profiler =
|
auto op_profiler =
|
||||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||||
for (int64 i = 0; i < loop_limit_; ++i) {
|
for (int64 i = 0; i < loop_limit_; ++i) {
|
||||||
params.profiler->StartHloComputation();
|
params.profiler->StartHloComputation();
|
||||||
// Invoke loop body thunk sequence.
|
// Invoke loop body thunk sequence.
|
||||||
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params));
|
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params));
|
||||||
params.profiler->FinishHloComputation(hlo_instruction_->while_body());
|
params.profiler->FinishHloComputation(body_profile_index_);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,8 @@ namespace gpu {
|
|||||||
class ForThunk : public Thunk {
|
class ForThunk : public Thunk {
|
||||||
public:
|
public:
|
||||||
ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
|
ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
|
||||||
std::unique_ptr<ThunkSequence> body_thunk_sequence);
|
std::unique_ptr<ThunkSequence> body_thunk_sequence,
|
||||||
|
absl::optional<size_t> body_profile_index_);
|
||||||
ForThunk(const ForThunk&) = delete;
|
ForThunk(const ForThunk&) = delete;
|
||||||
ForThunk& operator=(const ForThunk&) = delete;
|
ForThunk& operator=(const ForThunk&) = delete;
|
||||||
|
|
||||||
@ -41,9 +42,9 @@ class ForThunk : public Thunk {
|
|||||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const HloInstruction* hlo_instruction_;
|
|
||||||
const int64 loop_limit_;
|
const int64 loop_limit_;
|
||||||
std::unique_ptr<SequentialThunk> body_thunk_sequence_;
|
std::unique_ptr<SequentialThunk> body_thunk_sequence_;
|
||||||
|
absl::optional<size_t> body_profile_index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -2272,8 +2272,15 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
|
|||||||
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
|
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
|
||||||
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
|
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
|
||||||
|
|
||||||
|
const auto* index_map = ir_emitter_context_->profile_index_map();
|
||||||
|
absl::optional<size_t> body_profile_index;
|
||||||
|
if (index_map) {
|
||||||
|
body_profile_index = index_map->GetProfileIndexFor(*body);
|
||||||
|
}
|
||||||
|
|
||||||
return std::unique_ptr<Thunk>(new ForThunk(
|
return std::unique_ptr<Thunk>(new ForThunk(
|
||||||
GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence()));
|
GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence(),
|
||||||
|
body_profile_index));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
|
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user