[NFC] Eliminate references to HLO Inst from ForThunk

- The HLO instruction is used only to obtain the loop body's profile index, so store that index directly in the thunk.

PiperOrigin-RevId: 335690149
Change-Id: Iee12dfa1b52daf373d42d2fb45f49535bc687f0a
This commit is contained in:
Rahul Joshi 2020-10-06 12:03:50 -07:00 committed by TensorFlower Gardener
parent 04ea37b5ed
commit 96fb44b7b4
3 changed files with 17 additions and 9 deletions

View File

@ -24,15 +24,16 @@ namespace xla {
namespace gpu {
// Constructs a ForThunk: registers with the Thunk base as Kind::kWhile, stores
// the iteration count `loop_limit`, wraps the moved-from `body_thunk_sequence`
// in a SequentialThunk, and stores `body_profile_index`, which
// ExecuteOnStream() hands to the profiler after each iteration.
//
// NOTE(review): this listing looks like a rendered diff with the +/- markers
// stripped, so the old and new forms of changed lines appear back to back
// (e.g. the two `body_thunk_sequence` parameter lines and the two `ThunkInfo()`
// lines) -- confirm before treating it as compilable code.
ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence)
std::unique_ptr<ThunkSequence> body_thunk_sequence,
absl::optional<size_t> body_profile_index)
: Thunk(Kind::kWhile, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
loop_limit_(loop_limit),
body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass a default-constructed ThunkInfo to the SequentialThunk
// constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it.
ThunkInfo(), std::move(*body_thunk_sequence))) {}
ThunkInfo(), std::move(*body_thunk_sequence))),
body_profile_index_(body_profile_index) {}
Status ForThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
@ -41,15 +42,14 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
}
// Executes the loop body thunk sequence `loop_limit_` times using `params`.
// An error from the body goes through TF_RETURN_IF_ERROR (presumably an early
// return of that status -- confirm against the macro); otherwise the function
// returns Status::OK() once the loop finishes.
//
// NOTE(review): this listing looks like a rendered diff with the +/- markers
// stripped; the two VLOG statements and the two FinishHloComputation calls
// appear to be the before/after forms of the same statement -- confirm.
Status ForThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
<< (hlo_instruction_ ? hlo_instruction_->ToString() : "<null>");
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters";
// Scoped profiler object for this thunk, keyed by profile_index(); it is held
// in a local for the rest of the function.
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
for (int64 i = 0; i < loop_limit_; ++i) {
// StartHloComputation/FinishHloComputation bracket each iteration of the
// loop body.
params.profiler->StartHloComputation();
// Invoke loop body thunk sequence.
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params));
params.profiler->FinishHloComputation(hlo_instruction_->while_body());
params.profiler->FinishHloComputation(body_profile_index_);
}
return Status::OK();
}

View File

@ -32,7 +32,8 @@ namespace gpu {
class ForThunk : public Thunk {
public:
ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence);
std::unique_ptr<ThunkSequence> body_thunk_sequence,
absl::optional<size_t> body_profile_index_);
ForThunk(const ForThunk&) = delete;
ForThunk& operator=(const ForThunk&) = delete;
@ -41,9 +42,9 @@ class ForThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
private:
const HloInstruction* hlo_instruction_;
const int64 loop_limit_;
std::unique_ptr<SequentialThunk> body_thunk_sequence_;
absl::optional<size_t> body_profile_index_;
};
} // namespace gpu

View File

@ -2272,8 +2272,15 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
const auto* index_map = ir_emitter_context_->profile_index_map();
absl::optional<size_t> body_profile_index;
if (index_map) {
body_profile_index = index_map->GetProfileIndexFor(*body);
}
return std::unique_ptr<Thunk>(new ForThunk(
GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence()));
GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence(),
body_profile_index));
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(