[NFC] Eliminate references to HLO Inst from ForThunk

- The HLO instruction was used only to look up the loop body's profile index, so store that index directly in the thunk instead.

PiperOrigin-RevId: 335690149
Change-Id: Iee12dfa1b52daf373d42d2fb45f49535bc687f0a
This commit is contained in:
Rahul Joshi 2020-10-06 12:03:50 -07:00 committed by TensorFlower Gardener
parent 04ea37b5ed
commit 96fb44b7b4
3 changed files with 17 additions and 9 deletions

View File

@ -24,15 +24,16 @@ namespace xla {
namespace gpu { namespace gpu {
ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit, ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence) std::unique_ptr<ThunkSequence> body_thunk_sequence,
absl::optional<size_t> body_profile_index)
: Thunk(Kind::kWhile, thunk_info), : Thunk(Kind::kWhile, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
loop_limit_(loop_limit), loop_limit_(loop_limit),
body_thunk_sequence_(absl::make_unique<SequentialThunk>( body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // Pass nullptr as the HloInstruction* to the body_thunk_sequence_
// constructor because this SequentialThunk is logically "part of" // constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it. // this ForThunk, and shouldn't be profiled separately from it.
ThunkInfo(), std::move(*body_thunk_sequence))) {} ThunkInfo(), std::move(*body_thunk_sequence))),
body_profile_index_(body_profile_index) {}
Status ForThunk::Initialize(const GpuExecutable& executable, Status ForThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) { se::StreamExecutor* executor) {
@ -41,15 +42,14 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
} }
Status ForThunk::ExecuteOnStream(const ExecuteParams& params) { Status ForThunk::ExecuteOnStream(const ExecuteParams& params) {
VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters";
<< (hlo_instruction_ ? hlo_instruction_->ToString() : "<null>");
auto op_profiler = auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index()); params.profiler->MakeScopedInstructionProfiler(profile_index());
for (int64 i = 0; i < loop_limit_; ++i) { for (int64 i = 0; i < loop_limit_; ++i) {
params.profiler->StartHloComputation(); params.profiler->StartHloComputation();
// Invoke loop body thunk sequence. // Invoke loop body thunk sequence.
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params));
params.profiler->FinishHloComputation(hlo_instruction_->while_body()); params.profiler->FinishHloComputation(body_profile_index_);
} }
return Status::OK(); return Status::OK();
} }

View File

@ -32,7 +32,8 @@ namespace gpu {
class ForThunk : public Thunk { class ForThunk : public Thunk {
public: public:
ForThunk(ThunkInfo thunk_info, const int64 loop_limit, ForThunk(ThunkInfo thunk_info, const int64 loop_limit,
std::unique_ptr<ThunkSequence> body_thunk_sequence); std::unique_ptr<ThunkSequence> body_thunk_sequence,
absl::optional<size_t> body_profile_index_);
ForThunk(const ForThunk&) = delete; ForThunk(const ForThunk&) = delete;
ForThunk& operator=(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete;
@ -41,9 +42,9 @@ class ForThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override; Status ExecuteOnStream(const ExecuteParams& params) override;
private: private:
const HloInstruction* hlo_instruction_;
const int64 loop_limit_; const int64 loop_limit_;
std::unique_ptr<SequentialThunk> body_thunk_sequence_; std::unique_ptr<SequentialThunk> body_thunk_sequence_;
absl::optional<size_t> body_profile_index_;
}; };
} // namespace gpu } // namespace gpu

View File

@ -2272,8 +2272,15 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
const auto* index_map = ir_emitter_context_->profile_index_map();
absl::optional<size_t> body_profile_index;
if (index_map) {
body_profile_index = index_map->GetProfileIndexFor(*body);
}
return std::unique_ptr<Thunk>(new ForThunk( return std::unique_ptr<Thunk>(new ForThunk(
GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence())); GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence(),
body_profile_index));
} }
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk( StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(