From 5bbf4a1d112885f281b92ae94124ce929ea1e969 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Fri, 10 Jul 2020 15:25:28 -0700 Subject: [PATCH] [XLA/GPU] Remove uses of Thunk::hlo_instruction() for profiling. This CL consists of two steps: * First, refactor all Thunks to take an ThunkInfo instead of const HloInstruction*. This will benefit future extensions to ThunkInfo as we move away from HloInstruction*. * Secondly, change the data pipeline from: Emitter -> Thunk* -> hlo_instruction() -> profiler(HloInstruction*) to: Emitter -> Thunk with profile indices The profile doesn't really depend on HloInstruction*, but just its pointer identity. Removing the dependency on HloInstruction helps with MLIR migration. PiperOrigin-RevId: 320687291 Change-Id: I7027d4c032f73ed615e5b520e01f3740781735be --- tensorflow/compiler/xla/service/gpu/BUILD | 3 +- .../xla/service/gpu/cholesky_thunk.cc | 15 +-- .../compiler/xla/service/gpu/cholesky_thunk.h | 7 +- .../service/gpu/collective_permute_thunk.cc | 8 +- .../service/gpu/collective_permute_thunk.h | 6 +- .../xla/service/gpu/conditional_thunk.cc | 14 +-- .../xla/service/gpu/conditional_thunk.h | 4 +- .../xla/service/gpu/convolution_thunk.cc | 13 +-- .../xla/service/gpu/convolution_thunk.h | 2 +- .../compiler/xla/service/gpu/copy_thunk.cc | 18 ++-- .../compiler/xla/service/gpu/copy_thunk.h | 10 +- .../xla/service/gpu/cudnn_batchnorm_thunk.cc | 27 +++--- .../xla/service/gpu/cudnn_batchnorm_thunk.h | 16 ++-- .../xla/service/gpu/custom_call_thunk.cc | 8 +- .../xla/service/gpu/custom_call_thunk.h | 5 +- .../xla/service/gpu/dummy_all_reduce_thunk.cc | 6 +- .../compiler/xla/service/gpu/fft_thunk.cc | 10 +- .../compiler/xla/service/gpu/fft_thunk.h | 6 +- .../compiler/xla/service/gpu/for_thunk.cc | 11 +-- .../compiler/xla/service/gpu/for_thunk.h | 5 +- .../xla/service/gpu/gemm_algorithm_picker.cc | 1 + .../compiler/xla/service/gpu/gemm_thunk.cc | 11 ++- .../compiler/xla/service/gpu/gemm_thunk.h | 6 +- .../compiler/xla/service/gpu/gpu_compiler.cc | 58 ++++++------ .../xla/service/gpu/hlo_execution_profiler.cc | 23 ++--- .../xla/service/gpu/hlo_execution_profiler.h | 24 ++--- .../compiler/xla/service/gpu/infeed_thunk.cc | 8 +- .../compiler/xla/service/gpu/infeed_thunk.h | 4 +- .../xla/service/gpu/ir_emitter_context.h | 6 +- .../xla/service/gpu/ir_emitter_unnested.cc | 92 +++++++++++-------- .../xla/service/gpu/ir_emitter_unnested.h | 2 + .../compiler/xla/service/gpu/kernel_thunk.cc | 10 +- .../compiler/xla/service/gpu/kernel_thunk.h | 5 +- .../compiler/xla/service/gpu/memset_thunk.cc | 4 +- .../compiler/xla/service/gpu/memset_thunk.h | 15 +-- .../xla/service/gpu/nccl_all_reduce_thunk.cc | 8 +- .../xla/service/gpu/nccl_all_reduce_thunk.h | 4 +- .../compiler/xla/service/gpu/outfeed_thunk.cc | 8 +- .../compiler/xla/service/gpu/outfeed_thunk.h | 4 +- .../xla/service/gpu/replica_id_thunk.cc | 8 +- .../xla/service/gpu/replica_id_thunk.h | 3 +- .../xla/service/gpu/sequential_thunk.cc | 8 +- .../xla/service/gpu/sequential_thunk.h | 4 +- tensorflow/compiler/xla/service/gpu/thunk.h | 20 +++- .../compiler/xla/service/gpu/thunk_emitter.cc | 72 +++++++++------ .../compiler/xla/service/gpu/thunk_emitter.h | 1 + .../xla/service/gpu/triangular_solve_thunk.cc | 6 +- .../xla/service/gpu/triangular_solve_thunk.h | 6 +- .../compiler/xla/service/gpu/tuple_thunk.cc | 2 +- .../compiler/xla/service/gpu/tuple_thunk.h | 8 +- .../compiler/xla/service/gpu/while_thunk.cc | 12 +-- .../compiler/xla/service/gpu/while_thunk.h | 6 +- .../xla/service/hlo_execution_profile.cc | 8 +- .../xla/service/hlo_execution_profile.h | 3 + .../service/mlir_gpu/lhlo_dialect_emitter.h | 1 + .../service/mlir_gpu/mlir_compiler_impl.cc | 6 +- 56 files changed, 366 insertions(+), 295 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 785122e23b4..b22f258bac6 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -242,7 +242,6 @@ cc_library( deps = [ ":backend_configs_cc", ":buffer_allocations", - ":cudnn_batchnorm_runner", ":elemental_ir_emitter", ":gpu_constants", ":gpu_conv_runner", @@ -267,6 +266,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla/service:while_loop_analysis", @@ -282,7 +282,6 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:sort_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc index b35cbb598cc..c34c299fea8 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc @@ -31,13 +31,13 @@ limitations under the License. namespace xla { namespace gpu { -CholeskyThunk::CholeskyThunk(const CholeskyOptions& options, +CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info, + const CholeskyOptions& options, BufferAllocation::Slice a_buffer, BufferAllocation::Slice workspace_buffer, BufferAllocation::Slice info_buffer, - PrimitiveType type, int64 batch_size, int64 n, - const HloInstruction* hlo) - : Thunk(Kind::kCholesky, hlo), + PrimitiveType type, int64 batch_size, int64 n) + : Thunk(Kind::kCholesky, thunk_info), uplo_(options.lower() ? se::blas::UpperLower::kLower : se::blas::UpperLower::kUpper), a_buffer_(a_buffer), @@ -45,9 +45,10 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options, info_buffer_(info_buffer), type_(type), batch_size_(batch_size), - a_batch_stride_(n * n * - ShapeUtil::ByteSizeOfPrimitiveType( - hlo->operand(0)->shape().element_type())), + a_batch_stride_( + n * n * + ShapeUtil::ByteSizeOfPrimitiveType( + thunk_info.hlo_instruction->operand(0)->shape().element_type())), n_(n) {} Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) { diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h index 50ecca51588..9950d09d765 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.h @@ -41,12 +41,11 @@ namespace gpu { class CholeskyThunk : public Thunk { public: static StatusOr ScratchBufferSize(int64 n); - CholeskyThunk(const CholeskyOptions& options, + CholeskyThunk(ThunkInfo thunk_info, const CholeskyOptions& options, BufferAllocation::Slice a_buffer, BufferAllocation::Slice workspace_buffer, - BufferAllocation::Slice info_buffer, - PrimitiveType type, - int64 batch_size, int64 n, const HloInstruction* hlo); + BufferAllocation::Slice info_buffer, PrimitiveType type, + int64 batch_size, int64 n); CholeskyThunk(const CholeskyThunk&) = delete; CholeskyThunk& operator=(const CholeskyThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index a5001d5168d..bb76bf02eba 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -218,14 +218,14 @@ RefcountingHashMap& GlobalRendezvousMap() { } // anonymous namespace CollectivePermuteThunk::CollectivePermuteThunk( - const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest, - const HloInstruction* instr) - : Thunk(kCollectivePermute, instr), src_(src), dest_(dest) {} + ThunkInfo thunk_info, const BufferAllocation::Slice& src, + const BufferAllocation::Slice& dest) + : Thunk(kCollectivePermute, thunk_info), src_(src), dest_(dest) {} Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { auto* instr = Cast(hlo_instruction()); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); // Rendezvous with the threads for all other devices that are participating in // this CollectivePermute. diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h index 3d451bc03f4..329db00c66a 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h @@ -26,9 +26,9 @@ namespace gpu { // Thunk that implements the collective-permute HLO. class CollectivePermuteThunk : public Thunk { public: - CollectivePermuteThunk(const BufferAllocation::Slice& src, - const BufferAllocation::Slice& dest, - const HloInstruction* instr); + CollectivePermuteThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& src, + const BufferAllocation::Slice& dest); Status ExecuteOnStream(const ExecuteParams& params) override; diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index e31f45942b1..041aa9b6fa3 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -24,12 +24,14 @@ namespace xla { namespace gpu { ConditionalThunk::ConditionalThunk( + ThunkInfo thunk_info, const BufferAllocation::Slice& branch_index_buffer_index, absl::Span branch_operand_buffer_indexes, - std::vector branch_thunk_sequences, - const HloInstruction* hlo) - : Thunk(Kind::kConditional, hlo), - branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED), + std::vector branch_thunk_sequences) + : Thunk(Kind::kConditional, thunk_info), + branch_index_is_bool_( + thunk_info.hlo_instruction->operand(0)->shape().element_type() == + PRED), branch_index_buffer_index_(branch_index_buffer_index), branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), branch_operand_buffer_indexes.end()) { @@ -39,7 +41,7 @@ ConditionalThunk::ConditionalThunk( branch_thunks_.reserve(branch_thunk_sequences.size()); for (auto& branch_thunk_sequence : branch_thunk_sequences) { branch_thunks_.emplace_back( - new SequentialThunk(std::move(branch_thunk_sequence), nullptr)); + new SequentialThunk(ThunkInfo(), std::move(branch_thunk_sequence))); } } @@ -67,7 +69,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { auto& profiler = *params.profiler; auto& stream = *params.stream; - auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction()); + auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index()); // Copy the predicate value from device. int32 branch_index = -1; bool pred = false; diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index 404e2131eff..a00285efa7c 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -43,10 +43,10 @@ namespace gpu { class ConditionalThunk : public Thunk { public: ConditionalThunk( + ThunkInfo thunk_info, const BufferAllocation::Slice& branch_index_buffer_index, absl::Span branch_operand_buffer_indexes, - std::vector branch_thunk_sequences, - const HloInstruction* hlo); + std::vector branch_thunk_sequences); ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 9a47f13db3e..df3dd6d4593 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" @@ -30,16 +31,16 @@ namespace xla { namespace gpu { ConvolutionThunk::ConvolutionThunk( - const HloCustomCallInstruction* cudnn_call, - std::vector operand_slices, + ThunkInfo thunk_info, std::vector operand_slices, BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, BufferAllocation::Slice tuple_result_slice) - : Thunk(Kind::kConvolution, cudnn_call), - cudnn_call_(cudnn_call), + : Thunk(Kind::kConvolution, thunk_info), operand_buffers_(std::move(operand_slices)), result_buffer_(result_slice), scratch_buffer_(scratch_slice), - tuple_result_buffer_(tuple_result_slice) {} + tuple_result_buffer_(tuple_result_slice) { + cudnn_call_ = Cast(hlo_instruction()); +} Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; @@ -56,7 +57,7 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { buffer_allocations.GetDeviceAddress(scratch_buffer_); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers), result_buffer, scratch, params.stream)); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index a65a9a7b36d..03fae88c6dc 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk { // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // // operand_slices should be in the same order as cudnn_call->operands(). - ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, + ConvolutionThunk(ThunkInfo thunk_info, std::vector operand_slices, BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index 763bfc24813..9439937d599 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -22,10 +22,9 @@ namespace xla { namespace gpu { HostToDeviceCopyThunk::HostToDeviceCopyThunk( - const void* source_address, - const BufferAllocation::Slice& destination_buffer, uint64 mem_size, - const HloInstruction* hlo_instruction) - : Thunk(Kind::kCopy, hlo_instruction), + ThunkInfo thunk_info, const void* source_address, + const BufferAllocation::Slice& destination_buffer, uint64 mem_size) + : Thunk(Kind::kCopy, thunk_info), source_address_(source_address), destination_buffer_(destination_buffer), mem_size_(mem_size) {} @@ -34,16 +33,15 @@ Status HostToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase destination_data = params.buffer_allocations->GetDeviceAddress(destination_buffer_); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); params.stream->ThenMemcpy(&destination_data, source_address_, mem_size_); return Status::OK(); } DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( - const BufferAllocation::Slice& source_buffer, - const BufferAllocation::Slice& destination_buffer, uint64 mem_size, - const HloInstruction* hlo_instruction) - : Thunk(Kind::kCopy, hlo_instruction), + ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, uint64 mem_size) + : Thunk(Kind::kCopy, thunk_info), source_buffer_(source_buffer), destination_buffer_(destination_buffer), mem_size_(mem_size) {} @@ -54,7 +52,7 @@ Status DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase source_data = params.buffer_allocations->GetDeviceAddress(source_buffer_); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); params.stream->ThenMemcpy(&destination_data, source_data, mem_size_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index c6b7757de6e..ada016c900f 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -33,9 +33,9 @@ class HostToDeviceCopyThunk : public Thunk { // Constructs a CopyThunk that copies host data from `source_address` to the // device buffer `destination_buffer`. `mem_size` is the size of the data in // bytes. - HostToDeviceCopyThunk(const void* source_address, + HostToDeviceCopyThunk(ThunkInfo thunk_info, const void* source_address, const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction); + uint64 mem_size); HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; @@ -54,10 +54,10 @@ class DeviceToDeviceCopyThunk : public Thunk { // Constructs a CopyThunk that copies host data from `source_buffer` to the // device buffer `destination_buffer`. `mem_size` is the size of the data in // bytes. - DeviceToDeviceCopyThunk(const BufferAllocation::Slice& source_buffer, + DeviceToDeviceCopyThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& source_buffer, const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, - const HloInstruction* hlo_instruction); + uint64 mem_size); DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index f785186b05b..36f415d9d89 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -92,12 +92,12 @@ void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) { } // namespace CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( - const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, - const BufferAllocation::Slice& output, const HloInstruction* hlo) - : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo), + const BufferAllocation::Slice& output) + : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info), operand_(operand), scale_(scale), offset_(offset), @@ -106,6 +106,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( epsilon_(epsilon), feature_index_(feature_index), output_(output) { + const auto* hlo = hlo_instruction(); CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardInferenceCallTarget); @@ -118,7 +119,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( const ExecuteParams& params) { auto& buffer_allocations = *params.buffer_allocations; auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); se::DeviceMemoryBase output_base = buffer_allocations.GetDeviceAddress(output_); se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_); @@ -139,14 +140,14 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( } CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( - const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, float epsilon, int64 feature_index, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, const BufferAllocation::Slice& output_inv_stddev, - const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) - : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo), + const BufferAllocation::Slice& output_tuple) + : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info), operand_(operand), scale_(scale), offset_(offset), @@ -156,6 +157,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( output_mean_(output_mean), output_inv_stddev_(output_inv_stddev), output_tuple_(output_tuple) { + const auto* hlo = hlo_instruction(); CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); @@ -178,7 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( se::DeviceMemory null_device_ptr(nullptr); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); auto& stream = *params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining( hlo_instruction(), operand, output_data, output_mean, output_inv_stddev, @@ -203,15 +205,15 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( } CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( - const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& inv_stddev, const BufferAllocation::Slice& grad_output, float epsilon, int64 feature_index, const BufferAllocation::Slice& output_grad_data, const BufferAllocation::Slice& output_grad_scale, const BufferAllocation::Slice& output_grad_offset, - const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) - : Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo), + const BufferAllocation::Slice& output_tuple) + : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info), operand_(operand), scale_(scale), mean_(mean), @@ -223,6 +225,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( output_grad_scale_(output_grad_scale), output_grad_offset_(output_grad_offset), output_tuple_(output_tuple) { + const auto* hlo = hlo_instruction(); CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget); CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); @@ -247,7 +250,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(output_grad_offset_)); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); se::Stream* stream = params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward( hlo_instruction(), operand, output_grad_data, grad_output, diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index a69e37018c3..5897435a58f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -46,14 +46,14 @@ namespace gpu { class CudnnBatchNormForwardInferenceThunk : public Thunk { public: - CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand, + CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, - const BufferAllocation::Slice& output, - const HloInstruction* hlo); + const BufferAllocation::Slice& output); CudnnBatchNormForwardInferenceThunk( const CudnnBatchNormForwardInferenceThunk&) = delete; @@ -76,13 +76,13 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { class CudnnBatchNormForwardTrainingThunk : public Thunk { public: CudnnBatchNormForwardTrainingThunk( - const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, float epsilon, int64 feature_index, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, const BufferAllocation::Slice& output_inv_stddev, - const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo); + const BufferAllocation::Slice& output_tuple); CudnnBatchNormForwardTrainingThunk( const CudnnBatchNormForwardTrainingThunk&) = delete; @@ -105,7 +105,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { class CudnnBatchNormBackwardThunk : public Thunk { public: - CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand, + CudnnBatchNormBackwardThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& inv_stddev, @@ -114,8 +115,7 @@ class CudnnBatchNormBackwardThunk : public Thunk { const BufferAllocation::Slice& output_grad_data, const BufferAllocation::Slice& output_grad_scale, const BufferAllocation::Slice& output_grad_offset, - const BufferAllocation::Slice& output_tuple, - const HloInstruction* hlo); + const BufferAllocation::Slice& output_tuple); CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete; CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) = diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc index 4687acd3cff..16a1f923c91 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -22,15 +22,15 @@ namespace xla { namespace gpu { CustomCallThunk::CustomCallThunk( - void* call_target, + ThunkInfo thunk_info, void* call_target, std::vector> operand_slices, - ShapeTree result_slices, std::string opaque, - const HloInstruction* instr) - : Thunk(Thunk::kCustomCall, instr), + ShapeTree result_slices, std::string opaque) + : Thunk(Thunk::kCustomCall, thunk_info), call_target_(call_target), operand_slices_(std::move(operand_slices)), result_slices_(std::move(result_slices)), opaque_(std::move(opaque)) { + const HloInstruction* instr = hlo_instruction(); CHECK_EQ(instr->operand_count(), operand_slices_.size()); for (int64 i = 0; i < instr->operand_count(); ++i) { const auto& s1 = operand_slices_[i].shape(); diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h index ff2d4c69a71..72175daf3dd 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h @@ -39,10 +39,9 @@ namespace gpu { class CustomCallThunk : public Thunk { public: CustomCallThunk( - void* call_target, + ThunkInfo thunk_info, void* call_target, std::vector> operand_slices, - ShapeTree result_slices, std::string opaque, - const HloInstruction* instr); + ShapeTree result_slices, std::string opaque); Status ExecuteOnStream(const ExecuteParams& params) override; diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc index 998a3ccb4ee..318b8aff176 100644 --- a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc @@ -42,9 +42,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { struct NcclAllReduceThunk::AuxData {}; NcclAllReduceThunk::NcclAllReduceThunk( - int64 replica_count, std::vector buffers, - const HloInstruction* all_reduce) - : Thunk(Thunk::kNcclAllReduce, all_reduce), + ThunkInfo thunk_info, int64 replica_count, + std::vector buffers) + : Thunk(Thunk::kNcclAllReduce, thunk_info), replica_count_(replica_count), buffers_(std::move(buffers)) {} diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 9d6be3c78ea..d3800c7e6b4 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -98,12 +98,12 @@ string FftTypeToString(se::fft::Type type) { } // namespace -FftThunk::FftThunk(FftType fft_type, absl::Span fft_length, +FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type, + absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, - const Shape& input_shape, const Shape& output_shape, - const HloInstruction* hlo) - : Thunk(Kind::kFft, hlo), + const Shape& input_shape, const Shape& output_shape) + : Thunk(Kind::kFft, thunk_info), fft_type_( FftTypeToSeType(fft_type, input_shape.element_type() == F64 || input_shape.element_type() == C128)), @@ -127,7 +127,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { buffer_allocations.memory_allocator()); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); if (fft_plan_ == nullptr) { const int64 fft_rank = fft_length_.size(); CHECK_LE(fft_rank, 3); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 95186c7f219..bde271216b5 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -62,11 +62,11 @@ class FftThunk : public Thunk { public: // Constructs a thunk for launching an FFT on a stream. // Semantics of null hlo_instruction argument are as in Thunk. - FftThunk(FftType fft_type, absl::Span fft_length, + FftThunk(ThunkInfo thunk_info, FftType fft_type, + absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, - const Shape& input_shape, const Shape& output_shape, - const HloInstruction* hlo); + const Shape& input_shape, const Shape& output_shape); FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_ FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 0a97f668b38..7fc3bdd4436 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -23,16 +23,15 @@ limitations under the License. namespace xla { namespace gpu { -ForThunk::ForThunk(const int64 loop_limit, - std::unique_ptr body_thunk_sequence, - const HloInstruction* hlo) - : Thunk(Kind::kWhile, hlo), +ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit, + std::unique_ptr body_thunk_sequence) + : Thunk(Kind::kWhile, thunk_info), loop_limit_(loop_limit), body_thunk_sequence_(absl::make_unique( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. - std::move(*body_thunk_sequence), nullptr)) {} + ThunkInfo(), std::move(*body_thunk_sequence))) {} void ForThunk::ComputeAnnotations() { Thunk::ComputeAnnotations(); @@ -49,7 +48,7 @@ Status ForThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " << (hlo_instruction() ? hlo_instruction()->ToString() : ""); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); for (int64 i = 0; i < loop_limit_; ++i) { params.profiler->StartHloComputation(); // Invoke loop body thunk sequence. diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 57402f70627..77a89ea6023 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -31,9 +31,8 @@ namespace gpu { // ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'. class ForThunk : public Thunk { public: - ForThunk(const int64 loop_limit, - std::unique_ptr body_thunk_sequence, - const HloInstruction* hlo); + ForThunk(ThunkInfo thunk_info, const int64 loop_limit, + std::unique_ptr body_thunk_sequence); ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 8316cb7d12d..0320496ea98 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -132,6 +132,7 @@ static StatusOr> DoUncachedGemmAutotune( CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer, stream, /*implements_whole_instruction=*/true, + /*profile_index=*/-1, /*profiler=*/nullptr, /*profile_result=*/&profile_result, algorithm) .ok()); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index d52e5410dab..561dfbe3137 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -33,13 +33,13 @@ limitations under the License. namespace xla { namespace gpu { -GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer, +GemmThunk::GemmThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice &lhs_buffer, const BufferAllocation::Slice &rhs_buffer, const BufferAllocation::Slice &output_buffer, bool implements_whole_instruction, - const HloInstruction *hlo_instruction, const GemmBackendConfig &backend_config) - : Thunk(Kind::kGemm, hlo_instruction), + : Thunk(Kind::kGemm, thunk_info), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), output_buffer_(output_buffer), @@ -57,7 +57,7 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams ¶ms) { se::DeviceMemoryBase output_data = get_device_address(output_buffer_); return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data, output_data, params.stream, implements_whole_instruction_, - params.profiler); + profile_index(), params.profiler); } // This struct contains the metadata of a matrix, e.g., its base address and @@ -160,6 +160,7 @@ Status RunGemm(const HloInstruction *gemm, se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, se::Stream *stream, bool implements_whole_instruction, + absl::optional profile_index, HloExecutionProfiler *profiler, se::blas::ProfileResult *profile_result, absl::optional algorithm) { @@ -240,7 +241,7 @@ Status RunGemm(const HloInstruction *gemm, rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim); std::unique_ptr op_profiler = profiler ? profiler->MakeScopedInstructionProfiler( - implements_whole_instruction ? gemm : nullptr) + implements_whole_instruction ? profile_index : -1) : nullptr; if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index b44cc40d295..2bccb7b3572 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -39,11 +39,10 @@ class GemmThunk : public Thunk { public: // Constructs a thunk that computes "output = (lhs rhs) * alpha" using // BLAS gemm (alpha is stored in the instruction GemmBackendConfig). - GemmThunk(const BufferAllocation::Slice& lhs_buffer, + GemmThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, bool implements_whole_instruction, - const HloInstruction* hlo_instruction, const GemmBackendConfig& backend_config); GemmThunk(const GemmThunk&) = delete; @@ -72,7 +71,8 @@ Status RunGemm( const HloInstruction* gemm, const GemmBackendConfig& backend_config, se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, se::Stream* stream, - bool implements_whole_instruction, HloExecutionProfiler* profiler = nullptr, + bool implements_whole_instruction, absl::optional profile_index, + HloExecutionProfiler* profiler = nullptr, se::blas::ProfileResult* profile_result = nullptr, absl::optional algorithm = absl::nullopt); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 758bba90bd2..3dcdb4c90eb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -472,7 +472,8 @@ static Status CompileModuleToLlvmIrImpl( const std::string& platform_name, GpuDeviceInfo gpu_device_info, absl::optional cuda_compute_capability, const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function, - int pointer_size, std::unique_ptr* llvm_module, + int pointer_size, const HloProfileIndexMap* profile_index_map, + std::unique_ptr* llvm_module, std::unique_ptr* buffer_assignment, std::unique_ptr* thunk_schedule) { *llvm_module = absl::make_unique("", *llvm_context); @@ -509,7 +510,7 @@ static Status CompileModuleToLlvmIrImpl( IrEmitterContext ir_emitter_context( hlo_module, buffer_assignment->get(), platform_name, gpu_device_info, - cuda_compute_capability, llvm_module->get()); + cuda_compute_capability, profile_index_map, llvm_module->get()); HloComputation* entry_computation = hlo_module->entry_computation(); IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation, @@ -532,10 +533,14 @@ static Status CompileModuleToLlvmIrImpl( // not all explicitly checked, but at least we can document them here: // * The entry HloComputation shall not have dead code (all reachable from // ROOT). - // * For each visit of HloInstruction, either none or one Thunk will be - // returned. + // * The visited instructions are all instructions in the entry + // computation. + // * For each visit of these HloInstructions, either none or one Thunk + // will be returned. // * If there is a thunk returned, thunk->hlo_instruction() equals the // input HloInstruction*. + // * A returned thunk may contain other sub-thunks. A sub-thunk may or may + // not have an associated hlo_instruction(). TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString(); if (!thunks->empty()) { auto thunk = std::move(thunks->front()); @@ -603,6 +608,25 @@ StatusOr> GpuCompiler::RunBackend( return cuda_compute_capability; }(); + std::unique_ptr profile_index_map; + std::unique_ptr profile_printer; + + if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { + HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + cost_analysis.set_bytes_per_second( + stream_exec->GetDeviceDescription().memory_bandwidth()); + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); + VLOG(1) << "HLO memory read+written: " + << tensorflow::strings::HumanReadableNumBytes( + cost_analysis.bytes_accessed()); + if (module->config().hlo_profiling_enabled()) { + profile_index_map = absl::make_unique(*module); + profile_printer = + CreateHloProfilePrinterData(*profile_index_map, cost_analysis, + module->entry_computation()->name()); + } + } + std::unique_ptr llvm_module; std::unique_ptr buffer_assignment; std::unique_ptr thunk_schedule; @@ -610,8 +634,8 @@ StatusOr> GpuCompiler::RunBackend( TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( module.get(), &llvm_context, target_triple_, data_layout_, stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability, - GetCanShareBuffer(), pointer_size_, &llvm_module, &buffer_assignment, - &thunk_schedule)); + GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module, + &buffer_assignment, &thunk_schedule)); if (user_pre_optimization_hook_) { user_pre_optimization_hook_(*llvm_module); @@ -653,25 +677,6 @@ StatusOr> GpuCompiler::RunBackend( thunk_schedule->ToString()); } - std::unique_ptr profile_index_map; - std::unique_ptr profile_printer; - - if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { - HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); - cost_analysis.set_bytes_per_second( - stream_exec->GetDeviceDescription().memory_bandwidth()); - TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - VLOG(1) << "HLO memory read+written: " - << tensorflow::strings::HumanReadableNumBytes( - cost_analysis.bytes_accessed()); - if (module->config().hlo_profiling_enabled()) { - profile_index_map = absl::make_unique(*module); - profile_printer = - CreateHloProfilePrinterData(*profile_index_map, cost_analysis, - module->entry_computation()->name()); - } - } - auto* gpu_executable = new GpuExecutable( backend_result.first, backend_result.second, gpu_version, std::move(thunk_schedule), std::move(module), @@ -709,7 +714,8 @@ StatusOr> CompileModuleToLlvmIr( TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( hlo_module, llvm_context, target_triple, data_layout, platform_name, gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction, - pointer_size, &llvm_module, &buffer_assignment, &thunk_schedule)); + pointer_size, /*profile_index_map=*/nullptr, &llvm_module, + &buffer_assignment, &thunk_schedule)); return llvm_module; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index b9c21e8edb2..1f83ec71984 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -23,7 +23,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -97,26 +96,24 @@ void HloExecutionProfiler::StartHloInstruction() { } } -void HloExecutionProfiler::FinishHloInstruction( - const HloInstruction* hlo_instruction) { +void HloExecutionProfiler::FinishHloInstruction(size_t index) { if (do_profile_) { - hlo_instructions_.erase(hlo_instruction); - profile_->SetCyclesTakenBy( - hlo_instruction, - GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + indices_.erase(index); + profile_->SetCyclesTakenBy(index, GetCyclesTaken(&timers_, sub_streams_, + stream_, clock_rate_ghz_)); } } std::unique_ptr HloExecutionProfiler::MakeScopedInstructionProfiler( - const HloInstruction* hlo_instruction) { - if (do_profile_ && hlo_instruction != nullptr) { + absl::optional index) { + if (do_profile_ && index.has_value()) { // Make sure that we are not already measuring the time for the same - // 'hlo_instruction'. - CHECK(hlo_instructions_.insert(hlo_instruction).second) - << hlo_instruction->name(); + // instruction. + // TODO(timshen): provide more useful printout. + CHECK(indices_.insert(*index).second) << *index; } - return absl::make_unique(this, hlo_instruction); + return absl::make_unique(this, index); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h index 80cde75f2bb..1189143e3f9 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -58,14 +57,17 @@ class HloExecutionProfiler { void StartHloInstruction(); // If profiling is enabled, stops the per-operation timer and records the time - // that the hlo_instruction took to execute in the profile. - void FinishHloInstruction(const HloInstruction* hlo_instruction); + // that at `profile_index`. Profile indices can be looked up from + // HloProfileIndexMap. + void FinishHloInstruction(size_t profile_index); // Returns a ScopedInstructionProfiler and triggers a call to // StartHloInstruction(). Once the returned ScopedInstructionProfiler goes // out of scope, it triggers a call to FinishHloInstruction(). + // + // If profile_index < 0, it results in a no-op. std::unique_ptr MakeScopedInstructionProfiler( - const HloInstruction* hlo_instruction); + absl::optional profile_index); private: const bool do_profile_; @@ -77,7 +79,7 @@ class HloExecutionProfiler { std::stack> timers_; // Contains the HLO instructions for which we are currently measuring the // time. - std::unordered_set hlo_instructions_; + std::unordered_set indices_; bool finished_execution_ = false; }; @@ -87,21 +89,21 @@ class HloExecutionProfiler { class ScopedInstructionProfiler { public: ScopedInstructionProfiler(HloExecutionProfiler* profiler, - const HloInstruction* hlo_instruction) - : profiler_(profiler), hlo_instruction_(hlo_instruction) { - if (hlo_instruction != nullptr) { + absl::optional index) + : profiler_(profiler), index_(index) { + if (index_.has_value()) { profiler->StartHloInstruction(); } } ~ScopedInstructionProfiler() { - if (hlo_instruction_ != nullptr) { - profiler_->FinishHloInstruction(hlo_instruction_); + if (index_.has_value()) { + profiler_->FinishHloInstruction(*index_); } } private: HloExecutionProfiler* profiler_; - const HloInstruction* hlo_instruction_; + absl::optional index_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 1b4cd45d1cc..43cc5f5a2ae 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -23,9 +23,9 @@ namespace xla { namespace gpu { InfeedThunk::InfeedThunk( - const ShapeTree& infeed_slices, - const HloInstruction* hlo_instruction) - : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} + ThunkInfo thunk_info, + const ShapeTree& infeed_slices) + : Thunk(Kind::kInfeed, thunk_info), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; @@ -34,7 +34,7 @@ Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); ShapeTree infeed_buffers = GetOrCreateInfeedManager()->BlockingGetNextDestination(); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index f04ac16fb08..ec33235c466 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -34,8 +34,8 @@ class InfeedThunk : public Thunk { public: // Constructs a InfeedThunk that copies data from the on-device // infeed queue into the buffers in the given shape tree. - InfeedThunk(const ShapeTree& infeed_slices, - const HloInstruction* hlo_instruction); + InfeedThunk(ThunkInfo thunk_info, + const ShapeTree& infeed_slices); InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 7678bb23184..9c43f80dc60 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" namespace xla { @@ -33,12 +34,13 @@ class IrEmitterContext { const HloModule* hlo_module, const BufferAssignment* buffer_assignment, std::string platform_name, GpuDeviceInfo gpu_device_info, absl::optional cuda_compute_capability, - llvm::Module* llvm_module) + const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), platform_name_(std::move(platform_name)), gpu_device_info_(gpu_device_info), cuda_compute_capability_(cuda_compute_capability), + profile_index_map_(profile_index_map), llvm_module_(llvm_module) {} // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; @@ -54,6 +56,7 @@ class IrEmitterContext { absl::optional cuda_compute_capability() const { return cuda_compute_capability_; } + const HloProfileIndexMap* profile_index_map() { return profile_index_map_; } llvm::Module* llvm_module() { return llvm_module_; } NameUniquer* name_uniquer() { return &name_uniquer_; } @@ -63,6 +66,7 @@ class IrEmitterContext { std::string platform_name_; GpuDeviceInfo gpu_device_info_; absl::optional cuda_compute_capability_; + const HloProfileIndexMap* profile_index_map_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 4c4ae47cd69..a232bf7fce5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -652,8 +652,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { /*updates_gen=*/ scatter_fused_emitter.GetGenerator(root->operand(2)))); } - AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), fusion)); + AddThunkToThunkSequence(absl::make_unique( + GetThunkInfo(fusion), std::move(thunks))); return Status::OK(); } // In the case of root tuple, it can be either reduce or slice input @@ -739,10 +739,11 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { auto destination_buffer = GetAllocationSlice(*copy); if (operand_buffer != destination_buffer) { AddThunkToThunkSequence(absl::make_unique( + GetThunkInfo(copy), /*source_address=*/operand_buffer, /*destination_buffer=*/destination_buffer, /*mem_size=*/ - ByteSizeOf(copy->operand(0)->shape()), copy)); + ByteSizeOf(copy->operand(0)->shape()))); } return Status::OK(); } @@ -816,7 +817,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } AddThunkToThunkSequence(absl::make_unique( - tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); + GetThunkInfo(tuple), tuple_element_buffers, + GetAllocationSlice(*tuple))); return Status::OK(); } AddThunkToThunkSequence( @@ -848,7 +850,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( thunks.push_back(BuildKernelThunk(select_and_scatter, /*implements_whole_instruction=*/false)); std::unique_ptr select_and_scatter_thunk = - absl::make_unique(std::move(thunks), select_and_scatter); + absl::make_unique(GetThunkInfo(select_and_scatter), + std::move(thunks)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -1082,10 +1085,10 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { auto destination_buffer = GetAllocationSlice(*scatter); if (operand_buffer != destination_buffer) { thunks.push_back(absl::make_unique( + Thunk::ThunkInfo(), /*source_address=*/operand_buffer, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), - /*hlo_instruction=*/nullptr)); + /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()))); } thunks.push_back( @@ -1109,8 +1112,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { if (thunks.size() == 1) { AddThunkToThunkSequence(std::move(thunks[0])); } else { - AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), scatter)); + AddThunkToThunkSequence(absl::make_unique( + GetThunkInfo(scatter), std::move(thunks))); } return Status::OK(); @@ -1282,10 +1285,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. thunks.push_back(absl::make_unique( + Thunk::ThunkInfo(), /*source_address=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()), - nullptr)); + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()))); } } @@ -1419,8 +1422,8 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } - AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), sort)); + AddThunkToThunkSequence(absl::make_unique( + GetThunkInfo(sort), std::move(thunks))); if (sort->operand_count() > 1) { // Emit the tuple as part of the last stage of sorting. // We are currently in the block sorted.in_bounds.after. @@ -1438,14 +1441,15 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { } Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) { - AddThunkToThunkSequence( - absl::make_unique(GetAllocationSlice(*hlo), hlo)); + AddThunkToThunkSequence(absl::make_unique( + GetThunkInfo(hlo), GetAllocationSlice(*hlo))); return Status::OK(); } Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) { AddThunkToThunkSequence(absl::make_unique( - GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo), hlo)); + GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), + GetAllocationSlice(*hlo))); return Status::OK(); } @@ -1478,15 +1482,16 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { tuple_element_buffers.push_back(buffers[i].destination_buffer); } auto all_reduce_thunk = absl::make_unique( + GetThunkInfo(crs), /*replica_count=*/hlo_module_config_.replica_count(), - /*buffers=*/std::move(buffers), crs); + /*buffers=*/std::move(buffers)); if (crs->shape().IsTuple()) { std::vector> thunks; thunks.push_back(std::move(all_reduce_thunk)); thunks.push_back(absl::make_unique( - tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); - AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), crs)); + Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs))); + AddThunkToThunkSequence(absl::make_unique( + GetThunkInfo(crs), std::move(thunks))); } else { AddThunkToThunkSequence(std::move(all_reduce_thunk)); } @@ -1520,9 +1525,10 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { CHECK(crs->operand(0)->shape().IsArray()) << "Operands to all-reduce must be arrays: " << crs->ToString(); AddThunkToThunkSequence(absl::make_unique( + GetThunkInfo(crs), /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), - /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()))); return Status::OK(); } @@ -1535,16 +1541,17 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { .GetUniqueSlice(crs, {i}) .ValueOrDie()); thunks.push_back(absl::make_unique( + Thunk::ThunkInfo(), /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), - /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()))); } // Output a tuple of the buffers above. thunks.push_back(absl::make_unique( - tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); + Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs))); AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), crs)); + absl::make_unique(GetThunkInfo(crs), std::move(thunks))); return Status::OK(); } @@ -1787,8 +1794,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( } return absl::make_unique( - non_constant_buffers, std::string(kernel->getName()), - implements_whole_instruction ? inst : nullptr); + implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(), + non_constant_buffers, std::string(kernel->getName())); } StatusOr> IrEmitterUnnested::BuildInitializerThunk( @@ -1838,8 +1845,8 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( absl::Span literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {absl::make_unique(GetAllocationSlice(*hlo, index), - nullptr)}; + return {absl::make_unique(Thunk::ThunkInfo(), + GetAllocationSlice(*hlo, index))}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -1857,7 +1864,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); return {absl::make_unique( - pattern32, GetAllocationSlice(*hlo, index), nullptr)}; + Thunk::ThunkInfo(), pattern32, GetAllocationSlice(*hlo, index))}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -1868,7 +1875,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); return {absl::make_unique( - word, GetAllocationSlice(*hlo, index), nullptr)}; + Thunk::ThunkInfo(), word, GetAllocationSlice(*hlo, index))}; } } @@ -2014,9 +2021,10 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( TF_CHECK_OK(body->Accept(&ir_emitter_body)); return absl::make_unique( + GetThunkInfo(hlo), GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition.ConsumeThunkSequence(), - ir_emitter_body.ConsumeThunkSequence(), hlo); + ir_emitter_body.ConsumeThunkSequence()); } std::unique_ptr IrEmitterUnnested::BuildForThunk( @@ -2031,8 +2039,8 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return absl::make_unique( - loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); + return absl::make_unique(GetThunkInfo(hlo), loop_limit, + ir_emitter_body.ConsumeThunkSequence()); } std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( @@ -2054,8 +2062,8 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( } return absl::make_unique( - GetAllocationSlice(*hlo->operand(0)), branch_operands, - std::move(branch_thunks), hlo); + GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands, + std::move(branch_thunks)); } Status IrEmitterUnnested::EmitTargetElementLoopInThunk( @@ -3589,8 +3597,8 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( ir_emitter_context_->llvm_module()); thunks.push_back(std::move(kernel_thunk)); - auto sequential_thunk = - absl::make_unique(std::move(thunks), unnested_hlo); + auto sequential_thunk = absl::make_unique( + GetThunkInfo(unnested_hlo), std::move(thunks)); AddThunkToThunkSequence(std::move(sequential_thunk)); return Status::OK(); @@ -3757,5 +3765,15 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( return emit_status; } +Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo( + const HloInstruction* hlo) const { + auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo); + if (const auto* index_map = ir_emitter_context_->profile_index_map()) { + info.profile_index.emplace( + static_cast(index_map->GetProfileIndexFor(*hlo))); + } + return info; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 1be3b8dbd26..019fcdf21db 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -548,6 +548,8 @@ class IrEmitterUnnested : public IrEmitter, // Returns the last generated thunk. Thunk* LastThunk() const { return thunk_sequence_.back().get(); } + Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override; + // The thunk sequence this IrEmitter generates for the input computation. ThunkSequence thunk_sequence_; diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 881c8e00779..19fef37db7e 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -33,10 +33,10 @@ limitations under the License. namespace xla { namespace gpu { -KernelThunk::KernelThunk(absl::Span args, - const string& kernel_name, - const HloInstruction* hlo_instruction) - : Thunk(Kind::kKernel, hlo_instruction), +KernelThunk::KernelThunk(ThunkInfo thunk_info, + absl::Span args, + const string& kernel_name) + : Thunk(Kind::kKernel, thunk_info), args_(args.begin(), args.end()), kernel_name_(kernel_name) {} @@ -114,7 +114,7 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { } auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions.threads_per_block(), launch_dimensions.block_count(), params.stream); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 8f1debe80e8..0717ccd5ac1 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -47,8 +47,9 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. - KernelThunk(absl::Span args, - const string& kernel_name, const HloInstruction* hlo_instruction); + KernelThunk(ThunkInfo thunk_info, + absl::Span args, + const string& kernel_name); KernelThunk(const KernelThunk&) = delete; KernelThunk& operator=(const KernelThunk&) = delete; ~KernelThunk() override = default; diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc index 7835b04ace6..b4762d07276 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -25,7 +25,7 @@ Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase dest_data = params.buffer_allocations->GetDeviceAddress(dest_); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); params.stream->ThenMemZero(&dest_data, dest_data.size()); return Status::OK(); } @@ -34,7 +34,7 @@ Status Memset32BitValueThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase dest_data = params.buffer_allocations->GetDeviceAddress(dest_); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); params.stream->ThenMemset32(&dest_data, value_, dest_data.size()); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index 347e05d9dd9..8a1890a0769 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -32,9 +32,9 @@ namespace gpu { // Thunk that zeroes out a given chunk of memory. class MemzeroThunk : public Thunk { public: - explicit MemzeroThunk(const BufferAllocation::Slice& dest, - const HloInstruction* hlo) - : Thunk(Kind::kMemzero, hlo), dest_(dest) {} + explicit MemzeroThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& dest) + : Thunk(Kind::kMemzero, thunk_info), dest_(dest) {} Status ExecuteOnStream(const ExecuteParams& params) override; @@ -46,10 +46,11 @@ class MemzeroThunk : public Thunk { // destination chunk must have size divisible by 32 bits. class Memset32BitValueThunk : public Thunk { public: - explicit Memset32BitValueThunk(uint32 value, - const BufferAllocation::Slice& dest, - const HloInstruction* hlo) - : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} + explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32 value, + const BufferAllocation::Slice& dest) + : Thunk(Kind::kMemset32BitValue, thunk_info), + value_(value), + dest_(dest) {} Status ExecuteOnStream(const ExecuteParams& params) override; diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 97eceb489f9..755413beeee 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -541,9 +541,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { } NcclAllReduceThunk::NcclAllReduceThunk( - int64 replica_count, std::vector buffers, - const HloInstruction* all_reduce) - : Thunk(Thunk::kNcclAllReduce, all_reduce), + ThunkInfo thunk_info, int64 replica_count, + std::vector buffers) + : Thunk(Thunk::kNcclAllReduce, thunk_info), replica_count_(replica_count), buffers_(std::move(buffers)), aux_data_(absl::make_unique()) { @@ -555,7 +555,7 @@ NcclAllReduceThunk::NcclAllReduceThunk( Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(1) << "Starting NcclAllReduceThunk."; auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); auto* instr = Cast(hlo_instruction()); int64 local_device_ordinal = params.stream->parent()->device_ordinal(); diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index 90091ed2c7b..1df4f0805a6 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -56,8 +56,8 @@ class NcclAllReduceThunk : public Thunk { BufferAllocation::Slice source_buffer; BufferAllocation::Slice destination_buffer; }; - NcclAllReduceThunk(int64 replica_count, std::vector buffers, - const HloInstruction* all_reduce); + NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count, + std::vector buffers); ~NcclAllReduceThunk() override; Status ExecuteOnStream(const ExecuteParams& params) override; diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 25ab1b54f07..104366fd78c 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -23,9 +23,9 @@ limitations under the License. namespace xla { namespace gpu { -OutfeedThunk::OutfeedThunk(ShapeTree outfeed_slices, - const HloInstruction* hlo_instruction) - : Thunk(Kind::kOutfeed, hlo_instruction), +OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, + ShapeTree outfeed_slices) + : Thunk(Kind::kOutfeed, thunk_info), outfeed_slices_(std::move(outfeed_slices)) {} Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { @@ -35,7 +35,7 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString(); auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(); ShapeTree>* outfeed_buffers = outfeed_manager->BlockingGetNextDestination(); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h index ec18ad8476c..e99174e3c6c 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h @@ -32,8 +32,8 @@ class OutfeedThunk : public Thunk { public: // Constructs a OutfeedThunk that copies data to the host-side // outfeed queue from the buffers in the given shape tree. - OutfeedThunk(ShapeTree outfeed_slices, - const HloInstruction* hlo_instruction); + OutfeedThunk(ThunkInfo thunk_info, + ShapeTree outfeed_slices); OutfeedThunk(const OutfeedThunk&) = delete; OutfeedThunk& operator=(const OutfeedThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc index a2178ba3faa..b6792bb7a26 100644 --- a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.cc @@ -18,13 +18,13 @@ limitations under the License. namespace xla { namespace gpu { -ReplicaIdThunk::ReplicaIdThunk(const BufferAllocation::Slice& dest, - const HloInstruction* instr) - : Thunk(Kind::kReplicaId, instr), dest_(dest) {} +ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& dest) + : Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {} Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_); TF_ASSIGN_OR_RETURN(int replica_id, diff --git a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h index 816931720e6..80aee41da39 100644 --- a/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/replica_id_thunk.h @@ -26,8 +26,7 @@ namespace gpu { // Thunk that implements the ReplicaId HLO. class ReplicaIdThunk : public Thunk { public: - ReplicaIdThunk(const BufferAllocation::Slice& dest, - const HloInstruction* instr); + ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest); Status ExecuteOnStream(const ExecuteParams& params) override; diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index 025ca60ef0c..15cf2493549 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -24,9 +24,9 @@ namespace gpu { using ::tensorflow::profiler::ScopedAnnotation; -SequentialThunk::SequentialThunk(std::vector> thunks, - const HloInstruction* hlo) - : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} +SequentialThunk::SequentialThunk(ThunkInfo thunk_info, + std::vector> thunks) + : Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {} void SequentialThunk::ComputeAnnotations() { for (const auto& thunk : thunks_) { @@ -44,7 +44,7 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable, Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); for (const auto& thunk : thunks_) { ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 3abb82c0b66..127c5bcf734 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -32,8 +32,8 @@ namespace gpu { // require multiple kernel launches or library calls. class SequentialThunk : public Thunk { public: - SequentialThunk(std::vector> thunks, - const HloInstruction* hlo); + SequentialThunk(ThunkInfo thunk_info, + std::vector> thunks); SequentialThunk(const SequentialThunk&) = delete; SequentialThunk& operator=(const SequentialThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index d0477d374af..0a5382291c9 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -68,13 +68,21 @@ class Thunk { kWhile, }; + struct ThunkInfo { + const HloInstruction* hlo_instruction = nullptr; + absl::optional profile_index; + // TODO(timshen): Remove hlo_instruction and add name(), + // profile_annotation() here. + }; + // The hlo_instruction argument is meant to be the instruction this thunk was // generated from, but Thunk never uses this argument other than to save it // to Thunk::hlo_instruction, so it can be null. - explicit Thunk(Kind kind, const HloInstruction* hlo_instruction) + explicit Thunk(Kind kind, ThunkInfo thunk_info) : kind_(kind), - hlo_instruction_(hlo_instruction), - name_(hlo_instruction_ ? hlo_instruction_->name() : "") {} + hlo_instruction_(thunk_info.hlo_instruction), + name_(hlo_instruction_ ? hlo_instruction_->name() : ""), + profile_index_(thunk_info.profile_index) {} virtual ~Thunk() {} Thunk(const Thunk&) = delete; Thunk& operator=(const Thunk&) = delete; @@ -128,6 +136,8 @@ class Thunk { protected: const HloInstruction* hlo_instruction() const { return hlo_instruction_; } + absl::optional profile_index() const { return profile_index_; } + const HloModuleConfig& GetModuleConfig() const { return hlo_instruction()->GetModule()->config(); } @@ -146,8 +156,12 @@ class Thunk { private: Kind kind_; + + // Will be removed in the future, as Thunk is migrating away from the + // monolithic HloInstruction. const HloInstruction* hlo_instruction_; std::string name_; + absl::optional profile_index_; string profile_annotation_; }; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc index 5da7aeaa2d1..089d70d658f 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -40,11 +40,11 @@ namespace gpu { std::unique_ptr ThunkEmitter::BuildFftThunk(const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); return absl::make_unique( - inst->fft_type(), inst->fft_length(), + context_->GetThunkInfo(inst), inst->fft_type(), inst->fft_length(), /*input_buffer=*/GetAllocationSlice(*operand), /*output_buffer=*/GetAllocationSlice(*inst), /*input_shape=*/operand->shape(), - /*output_shape=*/inst->shape(), inst); + /*output_shape=*/inst->shape()); } std::unique_ptr ThunkEmitter::BuildTriangularSolveThunk( @@ -63,11 +63,11 @@ std::unique_ptr ThunkEmitter::BuildTriangularSolveThunk( : n * n * elem_size; int64 b_batch_stride = m * n * elem_size; return absl::make_unique( - inst->triangular_solve_options(), + context_->GetThunkInfo(inst), inst->triangular_solve_options(), /*a_input_buffer=*/GetAllocationSlice(*a), /*b_input_buffer=*/GetAllocationSlice(*inst), inst->shape().element_type(), batch_size, m, n, a_batch_stride, - b_batch_stride, inst); + b_batch_stride); } std::unique_ptr ThunkEmitter::BuildGemmThunk( @@ -86,24 +86,27 @@ std::unique_ptr ThunkEmitter::BuildGemmThunk( if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { std::vector> thunks; thunks.push_back(absl::make_unique( + Thunk::ThunkInfo(), /*source_buffer=*/GetAllocationSlice(*bias), /*destination_buffer=*/GetAllocationSlice(*inst), - /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr)); + /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()))); thunks.push_back(absl::make_unique( + context_->GetThunkInfo(inst), GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. - /*implements_whole_instruction=*/false, inst, - std::move(gemm_config))); - return absl::make_unique(std::move(thunks), inst); + /*implements_whole_instruction=*/false, std::move(gemm_config))); + return absl::make_unique(context_->GetThunkInfo(inst), + std::move(thunks)); } } return absl::make_unique( + context_->GetThunkInfo(inst), GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. - /*implements_whole_instruction=*/true, inst, std::move(gemm_config)); + /*implements_whole_instruction=*/true, std::move(gemm_config)); } std::unique_ptr ThunkEmitter::BuildInfeedThunk( @@ -115,7 +118,7 @@ std::unique_ptr ThunkEmitter::BuildInfeedThunk( [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { *slice = GetAllocationSlice(*inst, index); }); - return absl::make_unique(slices, inst); + return absl::make_unique(context_->GetThunkInfo(inst), slices); } std::unique_ptr ThunkEmitter::BuildOutfeedThunk( @@ -130,7 +133,8 @@ std::unique_ptr ThunkEmitter::BuildOutfeedThunk( *slice = status_or_slice.ValueOrDie(); } }); - return absl::make_unique(std::move(slices), inst); + return absl::make_unique(context_->GetThunkInfo(inst), + std::move(slices)); } Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { @@ -152,6 +156,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { AddThunkToThunkSequence( absl::make_unique( + context_->GetThunkInfo(custom_call), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -159,8 +164,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { /*variance=*/GetAllocationSlice(*custom_call->operand(4)), /*epsilon=*/epsilon_value, /*feature_index=*/feature_index_value, - /*output=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + /*output=*/GetAllocationSlice(*custom_call))); return Status::OK(); } @@ -181,6 +185,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto output_inv_stddev = GetAllocationSlice(*custom_call, {2}); AddThunkToThunkSequence( absl::make_unique( + context_->GetThunkInfo(custom_call), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -189,8 +194,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { /*output_data=*/output_data, /*output_mean=*/output_mean, /*output_inv_stddev=*/output_inv_stddev, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + /*output_tuple=*/GetAllocationSlice(*custom_call))); return Status::OK(); } @@ -209,6 +213,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_scale = GetAllocationSlice(*custom_call, {1}); auto output_grad_offset = GetAllocationSlice(*custom_call, {2}); AddThunkToThunkSequence(absl::make_unique( + context_->GetThunkInfo(custom_call), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*mean=*/GetAllocationSlice(*custom_call->operand(2)), @@ -219,8 +224,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { /*output_grad_data=*/output_grad_data, /*output_grad_scale=*/output_grad_scale, /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + /*output_tuple=*/GetAllocationSlice(*custom_call))); return Status::OK(); } @@ -235,7 +239,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto scratch_slice = GetAllocationSlice(*custom_call, {1}); AddThunkToThunkSequence(absl::make_unique( - Cast(custom_call), std::move(operand_slices), + context_->GetThunkInfo(custom_call), std::move(operand_slices), conv_result_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } @@ -269,22 +273,23 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { if (operand_buffer != a_buffer) { thunks.push_back(absl::make_unique( + context_->GetThunkInfo(custom_call), /*source_address=*/operand_buffer, /*destination_buffer=*/a_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call)); + /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); } thunks.push_back(absl::make_unique( - options, a_buffer, workspace_buffer, info_buffer, - custom_call->operand(0)->shape().element_type(), batch_size, n, - custom_call)); + context_->GetThunkInfo(custom_call), options, a_buffer, + workspace_buffer, info_buffer, + custom_call->operand(0)->shape().element_type(), batch_size, n)); // Elide the sequential thunk if there's no copy. if (thunks.size() == 1) { AddThunkToThunkSequence(std::move(thunks[0])); } else { - AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), custom_call)); + AddThunkToThunkSequence(absl::make_unique( + context_->GetThunkInfo(custom_call), std::move(thunks))); } return Status::OK(); @@ -311,8 +316,9 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { ShapeTree result_slices = get_slices_for_instr(custom_call); AddThunkToThunkSequence(absl::make_unique( - call_target, std::move(operand_slices), std::move(result_slices), - Cast(custom_call)->opaque(), custom_call)); + context_->GetThunkInfo(custom_call), call_target, + std::move(operand_slices), std::move(result_slices), + Cast(custom_call)->opaque())); return Status::OK(); } #endif @@ -347,9 +353,10 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) { auto destination_buffer = GetAllocationSlice(*hlo); if (operand_buffer != destination_buffer) { thunks.push_back(absl::make_unique( + context_->GetThunkInfo(hlo), /*source_address=*/operand_buffer, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo)); + /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()))); } thunks.push_back(BuildTriangularSolveThunk(hlo)); @@ -358,8 +365,8 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) { if (thunks.size() == 1) { AddThunkToThunkSequence(std::move(thunks[0])); } else { - AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), hlo)); + AddThunkToThunkSequence(absl::make_unique( + context_->GetThunkInfo(hlo), std::move(thunks))); } return Status::OK(); } @@ -374,5 +381,12 @@ Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) { return Status::OK(); } +Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo( + const HloInstruction* hlo) const { + CHECK(hlo); + Thunk::ThunkInfo info; + info.hlo_instruction = hlo; + return info; +} } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.h b/tensorflow/compiler/xla/service/gpu/thunk_emitter.h index f4ef87dac5a..16b11a4d5e2 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.h @@ -36,6 +36,7 @@ class ThunkEmitter { const HloInstruction& hlo, const ShapeIndex& index) const = 0; virtual int64 ByteSizeOf(const Shape& shape) const = 0; virtual absl::string_view platform_name() const = 0; + virtual Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const; virtual ~EmissionContext() = default; }; diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc index c5926890e8c..04b9ec2d525 100644 --- a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc @@ -32,12 +32,12 @@ namespace xla { namespace gpu { TriangularSolveThunk::TriangularSolveThunk( - const TriangularSolveOptions& options, + ThunkInfo thunk_info, const TriangularSolveOptions& options, const BufferAllocation::Slice& a_buffer, const BufferAllocation::Slice& b_buffer, PrimitiveType type, int64 batch_size, int64 m, int64 n, int64 a_batch_stride, - int64 b_batch_stride, const HloInstruction* hlo) - : Thunk(Kind::kTriangularSolve, hlo), + int64 b_batch_stride) + : Thunk(Kind::kTriangularSolve, thunk_info), uplo_(options.lower() ? se::blas::UpperLower::kLower : se::blas::UpperLower::kUpper), side_(options.left_side() ? se::blas::Side::kLeft diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h index 4f76e29a051..d5ad8e3ddc0 100644 --- a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h @@ -38,12 +38,12 @@ namespace gpu { // Thread-compatible. class TriangularSolveThunk : public Thunk { public: - TriangularSolveThunk(const TriangularSolveOptions& options, + TriangularSolveThunk(ThunkInfo thunk_info, + const TriangularSolveOptions& options, const BufferAllocation::Slice& a_buffer, const BufferAllocation::Slice& b_buffer, PrimitiveType type, int64 batch_size, int64 m, int64 n, - int64 a_batch_stride, int64 b_batch_stride, - const HloInstruction* hlo); + int64 a_batch_stride, int64 b_batch_stride); TriangularSolveThunk(const TriangularSolveThunk&) = delete; TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index cbbbb7baf68..c161a349cbb 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -34,7 +34,7 @@ Status TupleThunk::ExecuteOnStream(const ExecuteParams& params) { } auto op_profiler = - params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); + params.profiler->MakeScopedInstructionProfiler(profile_index()); SafeH2DMemcpy(se::DeviceMemory( buffer_allocations.GetDeviceAddress(dest_buffer_)), std::move(tuple_data), n, &stream, diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index b3b1ff62c4b..6d6709b5d47 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -34,10 +34,10 @@ namespace gpu { // issue (b/31336476). class TupleThunk : public Thunk { public: - TupleThunk(absl::Span tuple_element_buffers, - const BufferAllocation::Slice& dest_buffer, - const HloInstruction* hlo_instruction) - : Thunk(Kind::kTuple, hlo_instruction), + TupleThunk(ThunkInfo thunk_info, + absl::Span tuple_element_buffers, + const BufferAllocation::Slice& dest_buffer) + : Thunk(Kind::kTuple, thunk_info), tuple_element_buffers_(tuple_element_buffers.begin(), tuple_element_buffers.end()), dest_buffer_(dest_buffer) {} diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 4134cd39832..47a24552b6c 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -24,20 +24,20 @@ namespace xla { namespace gpu { WhileThunk::WhileThunk( + ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence, - const HloInstruction* hlo) - : Thunk(Kind::kWhile, hlo), + std::unique_ptr body_thunk_sequence) + : Thunk(Kind::kWhile, thunk_info), condition_result_buffer_index_(condition_result_buffer_index), // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_ // and body_thunk_sequence_ constructors because these SequentialThunks // are logically "part of" this WhileThunk, and shouldn't be profiled // separately from it. condition_thunk_sequence_(absl::make_unique( - std::move(*condition_thunk_sequence), nullptr)), + ThunkInfo(), std::move(*condition_thunk_sequence))), body_thunk_sequence_(absl::make_unique( - std::move(*body_thunk_sequence), nullptr)) {} + ThunkInfo(), std::move(*body_thunk_sequence))) {} void WhileThunk::ComputeAnnotations() { Thunk::ComputeAnnotations(); @@ -61,7 +61,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { params.buffer_allocations->GetDeviceAddress( condition_result_buffer_index_); - auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction()); + auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index()); while (true) { // Invoke thunk sequence for while 'condition' computation. profiler.StartHloComputation(); diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 31db01b72ba..72d9415b309 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -39,10 +39,10 @@ namespace gpu { class WhileThunk : public Thunk { public: // Constructs a WhileThunk to compute while instruction 'hlo'. - WhileThunk(const BufferAllocation::Slice& condition_result_buffer_index, + WhileThunk(ThunkInfo thunk_info, + const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence, - const HloInstruction* hlo); + std::unique_ptr body_thunk_sequence); WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index a7785455cf1..dff3a3495ab 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -133,8 +133,12 @@ HloExecutionProfile::HloExecutionProfile( void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken) { - profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] = - cycles_taken; + SetCyclesTakenBy(hlo_profile_index_map_.GetProfileIndexFor(*hlo), + cycles_taken); +} + +void HloExecutionProfile::SetCyclesTakenBy(size_t index, uint64 cycles_taken) { + profile_counters_[index] = cycles_taken; } uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 47de1a25765..02cba91a23e 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -114,6 +114,9 @@ class HloExecutionProfile { // Record how many cycles this HLO took to execute. void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken); + // Record how many cycles this HLO took to execute. + void SetCyclesTakenBy(size_t index, uint64 cycles_taken); + // Returns how many cycles this HLO took to execute. Profiling information // may not be available for some instructions in which case zero is returned. uint64 GetCyclesTakenBy(const HloInstruction& hlo) const; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h index 145d3681b16..dbc372be429 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -88,6 +88,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, const HloInstruction& hlo, const ShapeIndex& index) const override; int64 ByteSizeOf(const Shape& shape) const override; absl::string_view platform_name() const override; + mlir::Location getLocation(const HloInstruction* instr) const; xla::mlir_gpu::EmissionContext* emission_context_; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index eb901e59fd8..2c2076bbd97 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -436,8 +436,10 @@ StatusOr> TransformKernelToXlaThunk( kernel, operand_to_value_map, ordered_operands, assignment, buffers)); // Finally, create the thunk and set the launch dimensions. - auto thunk = absl::make_unique( - buffers, kernel.getName().str(), instr); + gpu::Thunk::ThunkInfo info; + info.hlo_instruction = instr; + auto thunk = absl::make_unique(info, buffers, + kernel.getName().str()); // Set launch bounds. mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues();