[XLA/GPU] Remove uses of Thunk::hlo_instruction() for profiling.
This CL consists of two steps:
* First, refactor all Thunks to take a ThunkInfo instead of a const HloInstruction*. This will benefit future extensions to ThunkInfo as we move away from HloInstruction*.
* Second, change the data pipeline from:
    Emitter -> Thunk* -> hlo_instruction() -> profiler(HloInstruction*)
  to:
    Emitter -> Thunk with profile indices
The profile doesn't really depend on the contents of the HloInstruction, only
on its pointer identity. Removing the dependency on HloInstruction helps with the MLIR migration.
PiperOrigin-RevId: 320687291
Change-Id: I7027d4c032f73ed615e5b520e01f3740781735be
			
			
This commit is contained in:
		
							parent
							
								
									aa47bcc6f1
								
							
						
					
					
						commit
						5bbf4a1d11
					
				| @ -242,7 +242,6 @@ cc_library( | |||||||
|     deps = [ |     deps = [ | ||||||
|         ":backend_configs_cc", |         ":backend_configs_cc", | ||||||
|         ":buffer_allocations", |         ":buffer_allocations", | ||||||
|         ":cudnn_batchnorm_runner", |  | ||||||
|         ":elemental_ir_emitter", |         ":elemental_ir_emitter", | ||||||
|         ":gpu_constants", |         ":gpu_constants", | ||||||
|         ":gpu_conv_runner", |         ":gpu_conv_runner", | ||||||
| @ -267,6 +266,7 @@ cc_library( | |||||||
|         "//tensorflow/compiler/xla/service:elemental_ir_emitter", |         "//tensorflow/compiler/xla/service:elemental_ir_emitter", | ||||||
|         "//tensorflow/compiler/xla/service:hlo", |         "//tensorflow/compiler/xla/service:hlo", | ||||||
|         "//tensorflow/compiler/xla/service:hlo_casting_utils", |         "//tensorflow/compiler/xla/service:hlo_casting_utils", | ||||||
|  |         "//tensorflow/compiler/xla/service:hlo_execution_profile", | ||||||
|         "//tensorflow/compiler/xla/service:name_uniquer", |         "//tensorflow/compiler/xla/service:name_uniquer", | ||||||
|         "//tensorflow/compiler/xla/service:pattern_matcher", |         "//tensorflow/compiler/xla/service:pattern_matcher", | ||||||
|         "//tensorflow/compiler/xla/service:while_loop_analysis", |         "//tensorflow/compiler/xla/service:while_loop_analysis", | ||||||
| @ -282,7 +282,6 @@ cc_library( | |||||||
|         "//tensorflow/compiler/xla/service/llvm_ir:sort_util", |         "//tensorflow/compiler/xla/service/llvm_ir:sort_util", | ||||||
|         "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", |         "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", | ||||||
|         "//tensorflow/core:lib", |         "//tensorflow/core:lib", | ||||||
|         "//tensorflow/core:stream_executor_no_cuda", |  | ||||||
|         "@com_google_absl//absl/algorithm:container", |         "@com_google_absl//absl/algorithm:container", | ||||||
|         "@com_google_absl//absl/container:inlined_vector", |         "@com_google_absl//absl/container:inlined_vector", | ||||||
|         "@com_google_absl//absl/memory", |         "@com_google_absl//absl/memory", | ||||||
|  | |||||||
| @ -31,13 +31,13 @@ limitations under the License. | |||||||
| namespace xla { | namespace xla { | ||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| CholeskyThunk::CholeskyThunk(const CholeskyOptions& options, | CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info, | ||||||
|  |                              const CholeskyOptions& options, | ||||||
|                              BufferAllocation::Slice a_buffer, |                              BufferAllocation::Slice a_buffer, | ||||||
|                              BufferAllocation::Slice workspace_buffer, |                              BufferAllocation::Slice workspace_buffer, | ||||||
|                              BufferAllocation::Slice info_buffer, |                              BufferAllocation::Slice info_buffer, | ||||||
|                              PrimitiveType type, int64 batch_size, int64 n, |                              PrimitiveType type, int64 batch_size, int64 n) | ||||||
|                              const HloInstruction* hlo) |     : Thunk(Kind::kCholesky, thunk_info), | ||||||
|     : Thunk(Kind::kCholesky, hlo), |  | ||||||
|       uplo_(options.lower() ? se::blas::UpperLower::kLower |       uplo_(options.lower() ? se::blas::UpperLower::kLower | ||||||
|                             : se::blas::UpperLower::kUpper), |                             : se::blas::UpperLower::kUpper), | ||||||
|       a_buffer_(a_buffer), |       a_buffer_(a_buffer), | ||||||
| @ -45,9 +45,10 @@ CholeskyThunk::CholeskyThunk(const CholeskyOptions& options, | |||||||
|       info_buffer_(info_buffer), |       info_buffer_(info_buffer), | ||||||
|       type_(type), |       type_(type), | ||||||
|       batch_size_(batch_size), |       batch_size_(batch_size), | ||||||
|       a_batch_stride_(n * n * |       a_batch_stride_( | ||||||
|                       ShapeUtil::ByteSizeOfPrimitiveType( |           n * n * | ||||||
|                           hlo->operand(0)->shape().element_type())), |           ShapeUtil::ByteSizeOfPrimitiveType( | ||||||
|  |               thunk_info.hlo_instruction->operand(0)->shape().element_type())), | ||||||
|       n_(n) {} |       n_(n) {} | ||||||
| 
 | 
 | ||||||
| Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) { | Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) { | ||||||
|  | |||||||
| @ -41,12 +41,11 @@ namespace gpu { | |||||||
| class CholeskyThunk : public Thunk { | class CholeskyThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   static StatusOr<int64> ScratchBufferSize(int64 n); |   static StatusOr<int64> ScratchBufferSize(int64 n); | ||||||
|   CholeskyThunk(const CholeskyOptions& options, |   CholeskyThunk(ThunkInfo thunk_info, const CholeskyOptions& options, | ||||||
|                 BufferAllocation::Slice a_buffer, |                 BufferAllocation::Slice a_buffer, | ||||||
|                 BufferAllocation::Slice workspace_buffer, |                 BufferAllocation::Slice workspace_buffer, | ||||||
|                 BufferAllocation::Slice info_buffer, |                 BufferAllocation::Slice info_buffer, PrimitiveType type, | ||||||
|                 PrimitiveType type, |                 int64 batch_size, int64 n); | ||||||
|                 int64 batch_size, int64 n, const HloInstruction* hlo); |  | ||||||
| 
 | 
 | ||||||
|   CholeskyThunk(const CholeskyThunk&) = delete; |   CholeskyThunk(const CholeskyThunk&) = delete; | ||||||
|   CholeskyThunk& operator=(const CholeskyThunk&) = delete; |   CholeskyThunk& operator=(const CholeskyThunk&) = delete; | ||||||
|  | |||||||
| @ -218,14 +218,14 @@ RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() { | |||||||
| }  // anonymous namespace
 | }  // anonymous namespace
 | ||||||
| 
 | 
 | ||||||
| CollectivePermuteThunk::CollectivePermuteThunk( | CollectivePermuteThunk::CollectivePermuteThunk( | ||||||
|     const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest, |     ThunkInfo thunk_info, const BufferAllocation::Slice& src, | ||||||
|     const HloInstruction* instr) |     const BufferAllocation::Slice& dest) | ||||||
|     : Thunk(kCollectivePermute, instr), src_(src), dest_(dest) {} |     : Thunk(kCollectivePermute, thunk_info), src_(src), dest_(dest) {} | ||||||
| 
 | 
 | ||||||
| Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { | Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { | ||||||
|   auto* instr = Cast<HloCollectivePermuteInstruction>(hlo_instruction()); |   auto* instr = Cast<HloCollectivePermuteInstruction>(hlo_instruction()); | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
| 
 | 
 | ||||||
|   // Rendezvous with the threads for all other devices that are participating in
 |   // Rendezvous with the threads for all other devices that are participating in
 | ||||||
|   // this CollectivePermute.
 |   // this CollectivePermute.
 | ||||||
|  | |||||||
| @ -26,9 +26,9 @@ namespace gpu { | |||||||
| // Thunk that implements the collective-permute HLO.
 | // Thunk that implements the collective-permute HLO.
 | ||||||
| class CollectivePermuteThunk : public Thunk { | class CollectivePermuteThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   CollectivePermuteThunk(const BufferAllocation::Slice& src, |   CollectivePermuteThunk(ThunkInfo thunk_info, | ||||||
|                          const BufferAllocation::Slice& dest, |                          const BufferAllocation::Slice& src, | ||||||
|                          const HloInstruction* instr); |                          const BufferAllocation::Slice& dest); | ||||||
| 
 | 
 | ||||||
|   Status ExecuteOnStream(const ExecuteParams& params) override; |   Status ExecuteOnStream(const ExecuteParams& params) override; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -24,12 +24,14 @@ namespace xla { | |||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| ConditionalThunk::ConditionalThunk( | ConditionalThunk::ConditionalThunk( | ||||||
|  |     ThunkInfo thunk_info, | ||||||
|     const BufferAllocation::Slice& branch_index_buffer_index, |     const BufferAllocation::Slice& branch_index_buffer_index, | ||||||
|     absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes, |     absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes, | ||||||
|     std::vector<ThunkSequence> branch_thunk_sequences, |     std::vector<ThunkSequence> branch_thunk_sequences) | ||||||
|     const HloInstruction* hlo) |     : Thunk(Kind::kConditional, thunk_info), | ||||||
|     : Thunk(Kind::kConditional, hlo), |       branch_index_is_bool_( | ||||||
|       branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED), |           thunk_info.hlo_instruction->operand(0)->shape().element_type() == | ||||||
|  |           PRED), | ||||||
|       branch_index_buffer_index_(branch_index_buffer_index), |       branch_index_buffer_index_(branch_index_buffer_index), | ||||||
|       branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), |       branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), | ||||||
|                                      branch_operand_buffer_indexes.end()) { |                                      branch_operand_buffer_indexes.end()) { | ||||||
| @ -39,7 +41,7 @@ ConditionalThunk::ConditionalThunk( | |||||||
|   branch_thunks_.reserve(branch_thunk_sequences.size()); |   branch_thunks_.reserve(branch_thunk_sequences.size()); | ||||||
|   for (auto& branch_thunk_sequence : branch_thunk_sequences) { |   for (auto& branch_thunk_sequence : branch_thunk_sequences) { | ||||||
|     branch_thunks_.emplace_back( |     branch_thunks_.emplace_back( | ||||||
|         new SequentialThunk(std::move(branch_thunk_sequence), nullptr)); |         new SequentialThunk(ThunkInfo(), std::move(branch_thunk_sequence))); | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -67,7 +69,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   auto& profiler = *params.profiler; |   auto& profiler = *params.profiler; | ||||||
|   auto& stream = *params.stream; |   auto& stream = *params.stream; | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction()); |   auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index()); | ||||||
|   // Copy the predicate value from device.
 |   // Copy the predicate value from device.
 | ||||||
|   int32 branch_index = -1; |   int32 branch_index = -1; | ||||||
|   bool pred = false; |   bool pred = false; | ||||||
|  | |||||||
| @ -43,10 +43,10 @@ namespace gpu { | |||||||
| class ConditionalThunk : public Thunk { | class ConditionalThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   ConditionalThunk( |   ConditionalThunk( | ||||||
|  |       ThunkInfo thunk_info, | ||||||
|       const BufferAllocation::Slice& branch_index_buffer_index, |       const BufferAllocation::Slice& branch_index_buffer_index, | ||||||
|       absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes, |       absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes, | ||||||
|       std::vector<ThunkSequence> branch_thunk_sequences, |       std::vector<ThunkSequence> branch_thunk_sequences); | ||||||
|       const HloInstruction* hlo); |  | ||||||
| 
 | 
 | ||||||
|   ConditionalThunk(const ConditionalThunk&) = delete; |   ConditionalThunk(const ConditionalThunk&) = delete; | ||||||
|   ConditionalThunk& operator=(const ConditionalThunk&) = delete; |   ConditionalThunk& operator=(const ConditionalThunk&) = delete; | ||||||
|  | |||||||
| @ -21,6 +21,7 @@ limitations under the License. | |||||||
| #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" | #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" | ||||||
| #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" | #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" | ||||||
| #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" | #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" | ||||||
|  | #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" | ||||||
| #include "tensorflow/compiler/xla/types.h" | #include "tensorflow/compiler/xla/types.h" | ||||||
| #include "tensorflow/compiler/xla/util.h" | #include "tensorflow/compiler/xla/util.h" | ||||||
| #include "tensorflow/core/platform/logging.h" | #include "tensorflow/core/platform/logging.h" | ||||||
| @ -30,16 +31,16 @@ namespace xla { | |||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| ConvolutionThunk::ConvolutionThunk( | ConvolutionThunk::ConvolutionThunk( | ||||||
|     const HloCustomCallInstruction* cudnn_call, |     ThunkInfo thunk_info, std::vector<BufferAllocation::Slice> operand_slices, | ||||||
|     std::vector<BufferAllocation::Slice> operand_slices, |  | ||||||
|     BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, |     BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, | ||||||
|     BufferAllocation::Slice tuple_result_slice) |     BufferAllocation::Slice tuple_result_slice) | ||||||
|     : Thunk(Kind::kConvolution, cudnn_call), |     : Thunk(Kind::kConvolution, thunk_info), | ||||||
|       cudnn_call_(cudnn_call), |  | ||||||
|       operand_buffers_(std::move(operand_slices)), |       operand_buffers_(std::move(operand_slices)), | ||||||
|       result_buffer_(result_slice), |       result_buffer_(result_slice), | ||||||
|       scratch_buffer_(scratch_slice), |       scratch_buffer_(scratch_slice), | ||||||
|       tuple_result_buffer_(tuple_result_slice) {} |       tuple_result_buffer_(tuple_result_slice) { | ||||||
|  |   cudnn_call_ = Cast<HloCustomCallInstruction>(hlo_instruction()); | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { | Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { | ||||||
|   const auto& buffer_allocations = *params.buffer_allocations; |   const auto& buffer_allocations = *params.buffer_allocations; | ||||||
| @ -56,7 +57,7 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|       buffer_allocations.GetDeviceAddress(scratch_buffer_); |       buffer_allocations.GetDeviceAddress(scratch_buffer_); | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers), |   TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers), | ||||||
|                                 result_buffer, scratch, params.stream)); |                                 result_buffer, scratch, params.stream)); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk { | |||||||
|   // write a tuple (result, scratch_memory) into `tuple_result_buffer`.
 |   // write a tuple (result, scratch_memory) into `tuple_result_buffer`.
 | ||||||
|   //
 |   //
 | ||||||
|   // operand_slices should be in the same order as cudnn_call->operands().
 |   // operand_slices should be in the same order as cudnn_call->operands().
 | ||||||
|   ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, |   ConvolutionThunk(ThunkInfo thunk_info, | ||||||
|                    std::vector<BufferAllocation::Slice> operand_slices, |                    std::vector<BufferAllocation::Slice> operand_slices, | ||||||
|                    BufferAllocation::Slice result_slice, |                    BufferAllocation::Slice result_slice, | ||||||
|                    BufferAllocation::Slice scratch_slice, |                    BufferAllocation::Slice scratch_slice, | ||||||
|  | |||||||
| @ -22,10 +22,9 @@ namespace xla { | |||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| HostToDeviceCopyThunk::HostToDeviceCopyThunk( | HostToDeviceCopyThunk::HostToDeviceCopyThunk( | ||||||
|     const void* source_address, |     ThunkInfo thunk_info, const void* source_address, | ||||||
|     const BufferAllocation::Slice& destination_buffer, uint64 mem_size, |     const BufferAllocation::Slice& destination_buffer, uint64 mem_size) | ||||||
|     const HloInstruction* hlo_instruction) |     : Thunk(Kind::kCopy, thunk_info), | ||||||
|     : Thunk(Kind::kCopy, hlo_instruction), |  | ||||||
|       source_address_(source_address), |       source_address_(source_address), | ||||||
|       destination_buffer_(destination_buffer), |       destination_buffer_(destination_buffer), | ||||||
|       mem_size_(mem_size) {} |       mem_size_(mem_size) {} | ||||||
| @ -34,16 +33,15 @@ Status HostToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   se::DeviceMemoryBase destination_data = |   se::DeviceMemoryBase destination_data = | ||||||
|       params.buffer_allocations->GetDeviceAddress(destination_buffer_); |       params.buffer_allocations->GetDeviceAddress(destination_buffer_); | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   params.stream->ThenMemcpy(&destination_data, source_address_, mem_size_); |   params.stream->ThenMemcpy(&destination_data, source_address_, mem_size_); | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( | DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( | ||||||
|     const BufferAllocation::Slice& source_buffer, |     ThunkInfo thunk_info, const BufferAllocation::Slice& source_buffer, | ||||||
|     const BufferAllocation::Slice& destination_buffer, uint64 mem_size, |     const BufferAllocation::Slice& destination_buffer, uint64 mem_size) | ||||||
|     const HloInstruction* hlo_instruction) |     : Thunk(Kind::kCopy, thunk_info), | ||||||
|     : Thunk(Kind::kCopy, hlo_instruction), |  | ||||||
|       source_buffer_(source_buffer), |       source_buffer_(source_buffer), | ||||||
|       destination_buffer_(destination_buffer), |       destination_buffer_(destination_buffer), | ||||||
|       mem_size_(mem_size) {} |       mem_size_(mem_size) {} | ||||||
| @ -54,7 +52,7 @@ Status DeviceToDeviceCopyThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   se::DeviceMemoryBase source_data = |   se::DeviceMemoryBase source_data = | ||||||
|       params.buffer_allocations->GetDeviceAddress(source_buffer_); |       params.buffer_allocations->GetDeviceAddress(source_buffer_); | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   params.stream->ThenMemcpy(&destination_data, source_data, mem_size_); |   params.stream->ThenMemcpy(&destination_data, source_data, mem_size_); | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
|  | |||||||
| @ -33,9 +33,9 @@ class HostToDeviceCopyThunk : public Thunk { | |||||||
|   // Constructs a CopyThunk that copies host data from `source_address` to the
 |   // Constructs a CopyThunk that copies host data from `source_address` to the
 | ||||||
|   // device buffer `destination_buffer`. `mem_size` is the size of the data in
 |   // device buffer `destination_buffer`. `mem_size` is the size of the data in
 | ||||||
|   // bytes.
 |   // bytes.
 | ||||||
|   HostToDeviceCopyThunk(const void* source_address, |   HostToDeviceCopyThunk(ThunkInfo thunk_info, const void* source_address, | ||||||
|                         const BufferAllocation::Slice& destination_buffer, |                         const BufferAllocation::Slice& destination_buffer, | ||||||
|                         uint64 mem_size, const HloInstruction* hlo_instruction); |                         uint64 mem_size); | ||||||
| 
 | 
 | ||||||
|   HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; |   HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; | ||||||
|   HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; |   HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; | ||||||
| @ -54,10 +54,10 @@ class DeviceToDeviceCopyThunk : public Thunk { | |||||||
|   // Constructs a CopyThunk that copies host data from `source_buffer` to the
 |   // Constructs a CopyThunk that copies host data from `source_buffer` to the
 | ||||||
|   // device buffer `destination_buffer`. `mem_size` is the size of the data in
 |   // device buffer `destination_buffer`. `mem_size` is the size of the data in
 | ||||||
|   // bytes.
 |   // bytes.
 | ||||||
|   DeviceToDeviceCopyThunk(const BufferAllocation::Slice& source_buffer, |   DeviceToDeviceCopyThunk(ThunkInfo thunk_info, | ||||||
|  |                           const BufferAllocation::Slice& source_buffer, | ||||||
|                           const BufferAllocation::Slice& destination_buffer, |                           const BufferAllocation::Slice& destination_buffer, | ||||||
|                           uint64 mem_size, |                           uint64 mem_size); | ||||||
|                           const HloInstruction* hlo_instruction); |  | ||||||
| 
 | 
 | ||||||
|   DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; |   DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; | ||||||
|   DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; |   DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; | ||||||
|  | |||||||
| @ -92,12 +92,12 @@ void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) { | |||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( | CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( | ||||||
|     const BufferAllocation::Slice& operand, |     ThunkInfo thunk_info, const BufferAllocation::Slice& operand, | ||||||
|     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, |     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, | ||||||
|     const BufferAllocation::Slice& mean, |     const BufferAllocation::Slice& mean, | ||||||
|     const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, |     const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, | ||||||
|     const BufferAllocation::Slice& output, const HloInstruction* hlo) |     const BufferAllocation::Slice& output) | ||||||
|     : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo), |     : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info), | ||||||
|       operand_(operand), |       operand_(operand), | ||||||
|       scale_(scale), |       scale_(scale), | ||||||
|       offset_(offset), |       offset_(offset), | ||||||
| @ -106,6 +106,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( | |||||||
|       epsilon_(epsilon), |       epsilon_(epsilon), | ||||||
|       feature_index_(feature_index), |       feature_index_(feature_index), | ||||||
|       output_(output) { |       output_(output) { | ||||||
|  |   const auto* hlo = hlo_instruction(); | ||||||
|   CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); |   CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); | ||||||
|   CHECK_EQ(hlo->custom_call_target(), |   CHECK_EQ(hlo->custom_call_target(), | ||||||
|            kCudnnBatchNormForwardInferenceCallTarget); |            kCudnnBatchNormForwardInferenceCallTarget); | ||||||
| @ -118,7 +119,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( | |||||||
|     const ExecuteParams& params) { |     const ExecuteParams& params) { | ||||||
|   auto& buffer_allocations = *params.buffer_allocations; |   auto& buffer_allocations = *params.buffer_allocations; | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   se::DeviceMemoryBase output_base = |   se::DeviceMemoryBase output_base = | ||||||
|       buffer_allocations.GetDeviceAddress(output_); |       buffer_allocations.GetDeviceAddress(output_); | ||||||
|   se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_); |   se::DeviceMemoryBase operand = buffer_allocations.GetDeviceAddress(operand_); | ||||||
| @ -139,14 +140,14 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( | CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( | ||||||
|     const BufferAllocation::Slice& operand, |     ThunkInfo thunk_info, const BufferAllocation::Slice& operand, | ||||||
|     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, |     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, | ||||||
|     float epsilon, int64 feature_index, |     float epsilon, int64 feature_index, | ||||||
|     const BufferAllocation::Slice& output_data, |     const BufferAllocation::Slice& output_data, | ||||||
|     const BufferAllocation::Slice& output_mean, |     const BufferAllocation::Slice& output_mean, | ||||||
|     const BufferAllocation::Slice& output_inv_stddev, |     const BufferAllocation::Slice& output_inv_stddev, | ||||||
|     const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) |     const BufferAllocation::Slice& output_tuple) | ||||||
|     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo), |     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info), | ||||||
|       operand_(operand), |       operand_(operand), | ||||||
|       scale_(scale), |       scale_(scale), | ||||||
|       offset_(offset), |       offset_(offset), | ||||||
| @ -156,6 +157,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( | |||||||
|       output_mean_(output_mean), |       output_mean_(output_mean), | ||||||
|       output_inv_stddev_(output_inv_stddev), |       output_inv_stddev_(output_inv_stddev), | ||||||
|       output_tuple_(output_tuple) { |       output_tuple_(output_tuple) { | ||||||
|  |   const auto* hlo = hlo_instruction(); | ||||||
|   CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); |   CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); | ||||||
|   CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); |   CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); | ||||||
|   CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); |   CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); | ||||||
| @ -178,7 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( | |||||||
| 
 | 
 | ||||||
|   se::DeviceMemory<float> null_device_ptr(nullptr); |   se::DeviceMemory<float> null_device_ptr(nullptr); | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   auto& stream = *params.stream; |   auto& stream = *params.stream; | ||||||
|   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining( |   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining( | ||||||
|       hlo_instruction(), operand, output_data, output_mean, output_inv_stddev, |       hlo_instruction(), operand, output_data, output_mean, output_inv_stddev, | ||||||
| @ -203,15 +205,15 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( | CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( | ||||||
|     const BufferAllocation::Slice& operand, |     ThunkInfo thunk_info, const BufferAllocation::Slice& operand, | ||||||
|     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, |     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, | ||||||
|     const BufferAllocation::Slice& inv_stddev, |     const BufferAllocation::Slice& inv_stddev, | ||||||
|     const BufferAllocation::Slice& grad_output, float epsilon, |     const BufferAllocation::Slice& grad_output, float epsilon, | ||||||
|     int64 feature_index, const BufferAllocation::Slice& output_grad_data, |     int64 feature_index, const BufferAllocation::Slice& output_grad_data, | ||||||
|     const BufferAllocation::Slice& output_grad_scale, |     const BufferAllocation::Slice& output_grad_scale, | ||||||
|     const BufferAllocation::Slice& output_grad_offset, |     const BufferAllocation::Slice& output_grad_offset, | ||||||
|     const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) |     const BufferAllocation::Slice& output_tuple) | ||||||
|     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo), |     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info), | ||||||
|       operand_(operand), |       operand_(operand), | ||||||
|       scale_(scale), |       scale_(scale), | ||||||
|       mean_(mean), |       mean_(mean), | ||||||
| @ -223,6 +225,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( | |||||||
|       output_grad_scale_(output_grad_scale), |       output_grad_scale_(output_grad_scale), | ||||||
|       output_grad_offset_(output_grad_offset), |       output_grad_offset_(output_grad_offset), | ||||||
|       output_tuple_(output_tuple) { |       output_tuple_(output_tuple) { | ||||||
|  |   const auto* hlo = hlo_instruction(); | ||||||
|   CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); |   CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); | ||||||
|   CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget); |   CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget); | ||||||
|   CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); |   CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); | ||||||
| @ -247,7 +250,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( | |||||||
|       buffer_allocations.GetDeviceAddress(output_grad_offset_)); |       buffer_allocations.GetDeviceAddress(output_grad_offset_)); | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   se::Stream* stream = params.stream; |   se::Stream* stream = params.stream; | ||||||
|   TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward( |   TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward( | ||||||
|       hlo_instruction(), operand, output_grad_data, grad_output, |       hlo_instruction(), operand, output_grad_data, grad_output, | ||||||
|  | |||||||
| @ -46,14 +46,14 @@ namespace gpu { | |||||||
| 
 | 
 | ||||||
| class CudnnBatchNormForwardInferenceThunk : public Thunk { | class CudnnBatchNormForwardInferenceThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand, |   CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info, | ||||||
|  |                                       const BufferAllocation::Slice& operand, | ||||||
|                                       const BufferAllocation::Slice& scale, |                                       const BufferAllocation::Slice& scale, | ||||||
|                                       const BufferAllocation::Slice& offset, |                                       const BufferAllocation::Slice& offset, | ||||||
|                                       const BufferAllocation::Slice& mean, |                                       const BufferAllocation::Slice& mean, | ||||||
|                                       const BufferAllocation::Slice& variance, |                                       const BufferAllocation::Slice& variance, | ||||||
|                                       float epsilon, int64 feature_index, |                                       float epsilon, int64 feature_index, | ||||||
|                                       const BufferAllocation::Slice& output, |                                       const BufferAllocation::Slice& output); | ||||||
|                                       const HloInstruction* hlo); |  | ||||||
| 
 | 
 | ||||||
|   CudnnBatchNormForwardInferenceThunk( |   CudnnBatchNormForwardInferenceThunk( | ||||||
|       const CudnnBatchNormForwardInferenceThunk&) = delete; |       const CudnnBatchNormForwardInferenceThunk&) = delete; | ||||||
| @ -76,13 +76,13 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { | |||||||
| class CudnnBatchNormForwardTrainingThunk : public Thunk { | class CudnnBatchNormForwardTrainingThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   CudnnBatchNormForwardTrainingThunk( |   CudnnBatchNormForwardTrainingThunk( | ||||||
|       const BufferAllocation::Slice& operand, |       ThunkInfo thunk_info, const BufferAllocation::Slice& operand, | ||||||
|       const BufferAllocation::Slice& scale, |       const BufferAllocation::Slice& scale, | ||||||
|       const BufferAllocation::Slice& offset, float epsilon, int64 feature_index, |       const BufferAllocation::Slice& offset, float epsilon, int64 feature_index, | ||||||
|       const BufferAllocation::Slice& output_data, |       const BufferAllocation::Slice& output_data, | ||||||
|       const BufferAllocation::Slice& output_mean, |       const BufferAllocation::Slice& output_mean, | ||||||
|       const BufferAllocation::Slice& output_inv_stddev, |       const BufferAllocation::Slice& output_inv_stddev, | ||||||
|       const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo); |       const BufferAllocation::Slice& output_tuple); | ||||||
| 
 | 
 | ||||||
|   CudnnBatchNormForwardTrainingThunk( |   CudnnBatchNormForwardTrainingThunk( | ||||||
|       const CudnnBatchNormForwardTrainingThunk&) = delete; |       const CudnnBatchNormForwardTrainingThunk&) = delete; | ||||||
| @ -105,7 +105,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { | |||||||
| 
 | 
 | ||||||
| class CudnnBatchNormBackwardThunk : public Thunk { | class CudnnBatchNormBackwardThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand, |   CudnnBatchNormBackwardThunk(ThunkInfo thunk_info, | ||||||
|  |                               const BufferAllocation::Slice& operand, | ||||||
|                               const BufferAllocation::Slice& scale, |                               const BufferAllocation::Slice& scale, | ||||||
|                               const BufferAllocation::Slice& mean, |                               const BufferAllocation::Slice& mean, | ||||||
|                               const BufferAllocation::Slice& inv_stddev, |                               const BufferAllocation::Slice& inv_stddev, | ||||||
| @ -114,8 +115,7 @@ class CudnnBatchNormBackwardThunk : public Thunk { | |||||||
|                               const BufferAllocation::Slice& output_grad_data, |                               const BufferAllocation::Slice& output_grad_data, | ||||||
|                               const BufferAllocation::Slice& output_grad_scale, |                               const BufferAllocation::Slice& output_grad_scale, | ||||||
|                               const BufferAllocation::Slice& output_grad_offset, |                               const BufferAllocation::Slice& output_grad_offset, | ||||||
|                               const BufferAllocation::Slice& output_tuple, |                               const BufferAllocation::Slice& output_tuple); | ||||||
|                               const HloInstruction* hlo); |  | ||||||
| 
 | 
 | ||||||
|   CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete; |   CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete; | ||||||
|   CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) = |   CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) = | ||||||
|  | |||||||
| @ -22,15 +22,15 @@ namespace xla { | |||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| CustomCallThunk::CustomCallThunk( | CustomCallThunk::CustomCallThunk( | ||||||
|     void* call_target, |     ThunkInfo thunk_info, void* call_target, | ||||||
|     std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices, |     std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices, | ||||||
|     ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque, |     ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque) | ||||||
|     const HloInstruction* instr) |     : Thunk(Thunk::kCustomCall, thunk_info), | ||||||
|     : Thunk(Thunk::kCustomCall, instr), |  | ||||||
|       call_target_(call_target), |       call_target_(call_target), | ||||||
|       operand_slices_(std::move(operand_slices)), |       operand_slices_(std::move(operand_slices)), | ||||||
|       result_slices_(std::move(result_slices)), |       result_slices_(std::move(result_slices)), | ||||||
|       opaque_(std::move(opaque)) { |       opaque_(std::move(opaque)) { | ||||||
|  |   const HloInstruction* instr = hlo_instruction(); | ||||||
|   CHECK_EQ(instr->operand_count(), operand_slices_.size()); |   CHECK_EQ(instr->operand_count(), operand_slices_.size()); | ||||||
|   for (int64 i = 0; i < instr->operand_count(); ++i) { |   for (int64 i = 0; i < instr->operand_count(); ++i) { | ||||||
|     const auto& s1 = operand_slices_[i].shape(); |     const auto& s1 = operand_slices_[i].shape(); | ||||||
|  | |||||||
| @ -39,10 +39,9 @@ namespace gpu { | |||||||
| class CustomCallThunk : public Thunk { | class CustomCallThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   CustomCallThunk( |   CustomCallThunk( | ||||||
|       void* call_target, |       ThunkInfo thunk_info, void* call_target, | ||||||
|       std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices, |       std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices, | ||||||
|       ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque, |       ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque); | ||||||
|       const HloInstruction* instr); |  | ||||||
| 
 | 
 | ||||||
|   Status ExecuteOnStream(const ExecuteParams& params) override; |   Status ExecuteOnStream(const ExecuteParams& params) override; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -42,9 +42,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { | |||||||
| struct NcclAllReduceThunk::AuxData {}; | struct NcclAllReduceThunk::AuxData {}; | ||||||
| 
 | 
 | ||||||
| NcclAllReduceThunk::NcclAllReduceThunk( | NcclAllReduceThunk::NcclAllReduceThunk( | ||||||
|     int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers, |     ThunkInfo thunk_info, int64 replica_count, | ||||||
|     const HloInstruction* all_reduce) |     std::vector<NcclAllReduceThunk::Buffer> buffers) | ||||||
|     : Thunk(Thunk::kNcclAllReduce, all_reduce), |     : Thunk(Thunk::kNcclAllReduce, thunk_info), | ||||||
|       replica_count_(replica_count), |       replica_count_(replica_count), | ||||||
|       buffers_(std::move(buffers)) {} |       buffers_(std::move(buffers)) {} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -98,12 +98,12 @@ string FftTypeToString(se::fft::Type type) { | |||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length, | FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type, | ||||||
|  |                    absl::Span<const int64> fft_length, | ||||||
|                    const BufferAllocation::Slice& input_buffer, |                    const BufferAllocation::Slice& input_buffer, | ||||||
|                    const BufferAllocation::Slice& output_buffer, |                    const BufferAllocation::Slice& output_buffer, | ||||||
|                    const Shape& input_shape, const Shape& output_shape, |                    const Shape& input_shape, const Shape& output_shape) | ||||||
|                    const HloInstruction* hlo) |     : Thunk(Kind::kFft, thunk_info), | ||||||
|     : Thunk(Kind::kFft, hlo), |  | ||||||
|       fft_type_( |       fft_type_( | ||||||
|           FftTypeToSeType(fft_type, input_shape.element_type() == F64 || |           FftTypeToSeType(fft_type, input_shape.element_type() == F64 || | ||||||
|                                         input_shape.element_type() == C128)), |                                         input_shape.element_type() == C128)), | ||||||
| @ -127,7 +127,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|                                         buffer_allocations.memory_allocator()); |                                         buffer_allocations.memory_allocator()); | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   if (fft_plan_ == nullptr) { |   if (fft_plan_ == nullptr) { | ||||||
|     const int64 fft_rank = fft_length_.size(); |     const int64 fft_rank = fft_length_.size(); | ||||||
|     CHECK_LE(fft_rank, 3); |     CHECK_LE(fft_rank, 3); | ||||||
|  | |||||||
| @ -62,11 +62,11 @@ class FftThunk : public Thunk { | |||||||
|  public: |  public: | ||||||
|   // Constructs a thunk for launching an FFT on a stream.
 |   // Constructs a thunk for launching an FFT on a stream.
 | ||||||
|   // Semantics of null hlo_instruction argument are as in Thunk.
 |   // Semantics of null hlo_instruction argument are as in Thunk.
 | ||||||
|   FftThunk(FftType fft_type, absl::Span<const int64> fft_length, |   FftThunk(ThunkInfo thunk_info, FftType fft_type, | ||||||
|  |            absl::Span<const int64> fft_length, | ||||||
|            const BufferAllocation::Slice& input_buffer, |            const BufferAllocation::Slice& input_buffer, | ||||||
|            const BufferAllocation::Slice& output_buffer, |            const BufferAllocation::Slice& output_buffer, | ||||||
|            const Shape& input_shape, const Shape& output_shape, |            const Shape& input_shape, const Shape& output_shape); | ||||||
|            const HloInstruction* hlo); |  | ||||||
| 
 | 
 | ||||||
|   FftThunk(const FftThunk&) = delete;             // Cannot share fft_plan_
 |   FftThunk(const FftThunk&) = delete;             // Cannot share fft_plan_
 | ||||||
|   FftThunk& operator=(const FftThunk&) = delete;  // Cannot share fft_plan_
 |   FftThunk& operator=(const FftThunk&) = delete;  // Cannot share fft_plan_
 | ||||||
|  | |||||||
| @ -23,16 +23,15 @@ limitations under the License. | |||||||
| namespace xla { | namespace xla { | ||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| ForThunk::ForThunk(const int64 loop_limit, | ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit, | ||||||
|                    std::unique_ptr<ThunkSequence> body_thunk_sequence, |                    std::unique_ptr<ThunkSequence> body_thunk_sequence) | ||||||
|                    const HloInstruction* hlo) |     : Thunk(Kind::kWhile, thunk_info), | ||||||
|     : Thunk(Kind::kWhile, hlo), |  | ||||||
|       loop_limit_(loop_limit), |       loop_limit_(loop_limit), | ||||||
|       body_thunk_sequence_(absl::make_unique<SequentialThunk>( |       body_thunk_sequence_(absl::make_unique<SequentialThunk>( | ||||||
|           // Pass nullptr as the HloInstruction* to the body_thunk_sequence_
 |           // Pass nullptr as the HloInstruction* to the body_thunk_sequence_
 | ||||||
|           // constructor because this SequentialThunk is logically "part of"
 |           // constructor because this SequentialThunk is logically "part of"
 | ||||||
|           // this ForThunk, and shouldn't be profiled separately from it.
 |           // this ForThunk, and shouldn't be profiled separately from it.
 | ||||||
|           std::move(*body_thunk_sequence), nullptr)) {} |           ThunkInfo(), std::move(*body_thunk_sequence))) {} | ||||||
| 
 | 
 | ||||||
| void ForThunk::ComputeAnnotations() { | void ForThunk::ComputeAnnotations() { | ||||||
|   Thunk::ComputeAnnotations(); |   Thunk::ComputeAnnotations(); | ||||||
| @ -49,7 +48,7 @@ Status ForThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " |   VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " | ||||||
|           << (hlo_instruction() ? hlo_instruction()->ToString() : "<null>"); |           << (hlo_instruction() ? hlo_instruction()->ToString() : "<null>"); | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   for (int64 i = 0; i < loop_limit_; ++i) { |   for (int64 i = 0; i < loop_limit_; ++i) { | ||||||
|     params.profiler->StartHloComputation(); |     params.profiler->StartHloComputation(); | ||||||
|     // Invoke loop body thunk sequence.
 |     // Invoke loop body thunk sequence.
 | ||||||
|  | |||||||
| @ -31,9 +31,8 @@ namespace gpu { | |||||||
| // ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'.
 | // ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'.
 | ||||||
| class ForThunk : public Thunk { | class ForThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   ForThunk(const int64 loop_limit, |   ForThunk(ThunkInfo thunk_info, const int64 loop_limit, | ||||||
|            std::unique_ptr<ThunkSequence> body_thunk_sequence, |            std::unique_ptr<ThunkSequence> body_thunk_sequence); | ||||||
|            const HloInstruction* hlo); |  | ||||||
|   ForThunk(const ForThunk&) = delete; |   ForThunk(const ForThunk&) = delete; | ||||||
|   ForThunk& operator=(const ForThunk&) = delete; |   ForThunk& operator=(const ForThunk&) = delete; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -132,6 +132,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune( | |||||||
|     CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer, |     CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer, | ||||||
|                   stream, |                   stream, | ||||||
|                   /*implements_whole_instruction=*/true, |                   /*implements_whole_instruction=*/true, | ||||||
|  |                   /*profile_index=*/-1, | ||||||
|                   /*profiler=*/nullptr, |                   /*profiler=*/nullptr, | ||||||
|                   /*profile_result=*/&profile_result, algorithm) |                   /*profile_result=*/&profile_result, algorithm) | ||||||
|               .ok()); |               .ok()); | ||||||
|  | |||||||
| @ -33,13 +33,13 @@ limitations under the License. | |||||||
| namespace xla { | namespace xla { | ||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer, | GemmThunk::GemmThunk(ThunkInfo thunk_info, | ||||||
|  |                      const BufferAllocation::Slice &lhs_buffer, | ||||||
|                      const BufferAllocation::Slice &rhs_buffer, |                      const BufferAllocation::Slice &rhs_buffer, | ||||||
|                      const BufferAllocation::Slice &output_buffer, |                      const BufferAllocation::Slice &output_buffer, | ||||||
|                      bool implements_whole_instruction, |                      bool implements_whole_instruction, | ||||||
|                      const HloInstruction *hlo_instruction, |  | ||||||
|                      const GemmBackendConfig &backend_config) |                      const GemmBackendConfig &backend_config) | ||||||
|     : Thunk(Kind::kGemm, hlo_instruction), |     : Thunk(Kind::kGemm, thunk_info), | ||||||
|       lhs_buffer_(lhs_buffer), |       lhs_buffer_(lhs_buffer), | ||||||
|       rhs_buffer_(rhs_buffer), |       rhs_buffer_(rhs_buffer), | ||||||
|       output_buffer_(output_buffer), |       output_buffer_(output_buffer), | ||||||
| @ -57,7 +57,7 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) { | |||||||
|   se::DeviceMemoryBase output_data = get_device_address(output_buffer_); |   se::DeviceMemoryBase output_data = get_device_address(output_buffer_); | ||||||
|   return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data, |   return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data, | ||||||
|                  output_data, params.stream, implements_whole_instruction_, |                  output_data, params.stream, implements_whole_instruction_, | ||||||
|                  params.profiler); |                  profile_index(), params.profiler); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // This struct contains the metadata of a matrix, e.g., its base address and
 | // This struct contains the metadata of a matrix, e.g., its base address and
 | ||||||
| @ -160,6 +160,7 @@ Status RunGemm(const HloInstruction *gemm, | |||||||
|                se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, |                se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, | ||||||
|                se::DeviceMemoryBase output_buffer, se::Stream *stream, |                se::DeviceMemoryBase output_buffer, se::Stream *stream, | ||||||
|                bool implements_whole_instruction, |                bool implements_whole_instruction, | ||||||
|  |                absl::optional<int64> profile_index, | ||||||
|                HloExecutionProfiler *profiler, |                HloExecutionProfiler *profiler, | ||||||
|                se::blas::ProfileResult *profile_result, |                se::blas::ProfileResult *profile_result, | ||||||
|                absl::optional<se::blas::AlgorithmType> algorithm) { |                absl::optional<se::blas::AlgorithmType> algorithm) { | ||||||
| @ -240,7 +241,7 @@ Status RunGemm(const HloInstruction *gemm, | |||||||
|       rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim); |       rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim); | ||||||
|   std::unique_ptr<ScopedInstructionProfiler> op_profiler = |   std::unique_ptr<ScopedInstructionProfiler> op_profiler = | ||||||
|       profiler ? profiler->MakeScopedInstructionProfiler( |       profiler ? profiler->MakeScopedInstructionProfiler( | ||||||
|                      implements_whole_instruction ? gemm : nullptr) |                      implements_whole_instruction ? profile_index : -1) | ||||||
|                : nullptr; |                : nullptr; | ||||||
| 
 | 
 | ||||||
|   if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) { |   if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) { | ||||||
|  | |||||||
| @ -39,11 +39,10 @@ class GemmThunk : public Thunk { | |||||||
|  public: |  public: | ||||||
|   // Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using
 |   // Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using
 | ||||||
|   // BLAS gemm (alpha is stored in the instruction GemmBackendConfig).
 |   // BLAS gemm (alpha is stored in the instruction GemmBackendConfig).
 | ||||||
|   GemmThunk(const BufferAllocation::Slice& lhs_buffer, |   GemmThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& lhs_buffer, | ||||||
|             const BufferAllocation::Slice& rhs_buffer, |             const BufferAllocation::Slice& rhs_buffer, | ||||||
|             const BufferAllocation::Slice& output_buffer, |             const BufferAllocation::Slice& output_buffer, | ||||||
|             bool implements_whole_instruction, |             bool implements_whole_instruction, | ||||||
|             const HloInstruction* hlo_instruction, |  | ||||||
|             const GemmBackendConfig& backend_config); |             const GemmBackendConfig& backend_config); | ||||||
| 
 | 
 | ||||||
|   GemmThunk(const GemmThunk&) = delete; |   GemmThunk(const GemmThunk&) = delete; | ||||||
| @ -72,7 +71,8 @@ Status RunGemm( | |||||||
|     const HloInstruction* gemm, const GemmBackendConfig& backend_config, |     const HloInstruction* gemm, const GemmBackendConfig& backend_config, | ||||||
|     se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, |     se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, | ||||||
|     se::DeviceMemoryBase output_buffer, se::Stream* stream, |     se::DeviceMemoryBase output_buffer, se::Stream* stream, | ||||||
|     bool implements_whole_instruction, HloExecutionProfiler* profiler = nullptr, |     bool implements_whole_instruction, absl::optional<int64> profile_index, | ||||||
|  |     HloExecutionProfiler* profiler = nullptr, | ||||||
|     se::blas::ProfileResult* profile_result = nullptr, |     se::blas::ProfileResult* profile_result = nullptr, | ||||||
|     absl::optional<se::blas::AlgorithmType> algorithm = absl::nullopt); |     absl::optional<se::blas::AlgorithmType> algorithm = absl::nullopt); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -472,7 +472,8 @@ static Status CompileModuleToLlvmIrImpl( | |||||||
|     const std::string& platform_name, GpuDeviceInfo gpu_device_info, |     const std::string& platform_name, GpuDeviceInfo gpu_device_info, | ||||||
|     absl::optional<CudaComputeCapability> cuda_compute_capability, |     absl::optional<CudaComputeCapability> cuda_compute_capability, | ||||||
|     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function, |     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function, | ||||||
|     int pointer_size, std::unique_ptr<llvm::Module>* llvm_module, |     int pointer_size, const HloProfileIndexMap* profile_index_map, | ||||||
|  |     std::unique_ptr<llvm::Module>* llvm_module, | ||||||
|     std::unique_ptr<BufferAssignment>* buffer_assignment, |     std::unique_ptr<BufferAssignment>* buffer_assignment, | ||||||
|     std::unique_ptr<ThunkSchedule>* thunk_schedule) { |     std::unique_ptr<ThunkSchedule>* thunk_schedule) { | ||||||
|   *llvm_module = absl::make_unique<llvm::Module>("", *llvm_context); |   *llvm_module = absl::make_unique<llvm::Module>("", *llvm_context); | ||||||
| @ -509,7 +510,7 @@ static Status CompileModuleToLlvmIrImpl( | |||||||
| 
 | 
 | ||||||
|   IrEmitterContext ir_emitter_context( |   IrEmitterContext ir_emitter_context( | ||||||
|       hlo_module, buffer_assignment->get(), platform_name, gpu_device_info, |       hlo_module, buffer_assignment->get(), platform_name, gpu_device_info, | ||||||
|       cuda_compute_capability, llvm_module->get()); |       cuda_compute_capability, profile_index_map, llvm_module->get()); | ||||||
| 
 | 
 | ||||||
|   HloComputation* entry_computation = hlo_module->entry_computation(); |   HloComputation* entry_computation = hlo_module->entry_computation(); | ||||||
|   IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation, |   IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation, | ||||||
| @ -532,10 +533,14 @@ static Status CompileModuleToLlvmIrImpl( | |||||||
|       // not all explicitly checked, but at least we can document them here:
 |       // not all explicitly checked, but at least we can document them here:
 | ||||||
|       // * The entry HloComputation shall not have dead code (all reachable from
 |       // * The entry HloComputation shall not have dead code (all reachable from
 | ||||||
|       // ROOT).
 |       // ROOT).
 | ||||||
|       // * For each visit of HloInstruction, either none or one Thunk will be
 |       // * The visited instructions are all instructions in the entry
 | ||||||
|       // returned.
 |       // computation.
 | ||||||
|  |       // * For each visit of these HloInstructions, either none or one Thunk
 | ||||||
|  |       // will be returned.
 | ||||||
|       // * If there is a thunk returned, thunk->hlo_instruction() equals the
 |       // * If there is a thunk returned, thunk->hlo_instruction() equals the
 | ||||||
|       // input HloInstruction*.
 |       // input HloInstruction*.
 | ||||||
|  |       // * A returned thunk may contain other sub-thunks. A sub-thunk may or may
 | ||||||
|  |       // not have an associated hlo_instruction().
 | ||||||
|       TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString(); |       TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString(); | ||||||
|       if (!thunks->empty()) { |       if (!thunks->empty()) { | ||||||
|         auto thunk = std::move(thunks->front()); |         auto thunk = std::move(thunks->front()); | ||||||
| @ -603,6 +608,25 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend( | |||||||
|     return cuda_compute_capability; |     return cuda_compute_capability; | ||||||
|   }(); |   }(); | ||||||
| 
 | 
 | ||||||
|  |   std::unique_ptr<HloProfileIndexMap> profile_index_map; | ||||||
|  |   std::unique_ptr<HloProfilePrinterData> profile_printer; | ||||||
|  | 
 | ||||||
|  |   if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { | ||||||
|  |     HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); | ||||||
|  |     cost_analysis.set_bytes_per_second( | ||||||
|  |         stream_exec->GetDeviceDescription().memory_bandwidth()); | ||||||
|  |     TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); | ||||||
|  |     VLOG(1) << "HLO memory read+written: " | ||||||
|  |             << tensorflow::strings::HumanReadableNumBytes( | ||||||
|  |                    cost_analysis.bytes_accessed()); | ||||||
|  |     if (module->config().hlo_profiling_enabled()) { | ||||||
|  |       profile_index_map = absl::make_unique<HloProfileIndexMap>(*module); | ||||||
|  |       profile_printer = | ||||||
|  |           CreateHloProfilePrinterData(*profile_index_map, cost_analysis, | ||||||
|  |                                       module->entry_computation()->name()); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   std::unique_ptr<llvm::Module> llvm_module; |   std::unique_ptr<llvm::Module> llvm_module; | ||||||
|   std::unique_ptr<BufferAssignment> buffer_assignment; |   std::unique_ptr<BufferAssignment> buffer_assignment; | ||||||
|   std::unique_ptr<ThunkSchedule> thunk_schedule; |   std::unique_ptr<ThunkSchedule> thunk_schedule; | ||||||
| @ -610,8 +634,8 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend( | |||||||
|   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( |   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( | ||||||
|       module.get(), &llvm_context, target_triple_, data_layout_, |       module.get(), &llvm_context, target_triple_, data_layout_, | ||||||
|       stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability, |       stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability, | ||||||
|       GetCanShareBuffer(), pointer_size_, &llvm_module, &buffer_assignment, |       GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module, | ||||||
|       &thunk_schedule)); |       &buffer_assignment, &thunk_schedule)); | ||||||
| 
 | 
 | ||||||
|   if (user_pre_optimization_hook_) { |   if (user_pre_optimization_hook_) { | ||||||
|     user_pre_optimization_hook_(*llvm_module); |     user_pre_optimization_hook_(*llvm_module); | ||||||
| @ -653,25 +677,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend( | |||||||
|                             thunk_schedule->ToString()); |                             thunk_schedule->ToString()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   std::unique_ptr<HloProfileIndexMap> profile_index_map; |  | ||||||
|   std::unique_ptr<HloProfilePrinterData> profile_printer; |  | ||||||
| 
 |  | ||||||
|   if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { |  | ||||||
|     HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); |  | ||||||
|     cost_analysis.set_bytes_per_second( |  | ||||||
|         stream_exec->GetDeviceDescription().memory_bandwidth()); |  | ||||||
|     TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); |  | ||||||
|     VLOG(1) << "HLO memory read+written: " |  | ||||||
|             << tensorflow::strings::HumanReadableNumBytes( |  | ||||||
|                    cost_analysis.bytes_accessed()); |  | ||||||
|     if (module->config().hlo_profiling_enabled()) { |  | ||||||
|       profile_index_map = absl::make_unique<HloProfileIndexMap>(*module); |  | ||||||
|       profile_printer = |  | ||||||
|           CreateHloProfilePrinterData(*profile_index_map, cost_analysis, |  | ||||||
|                                       module->entry_computation()->name()); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   auto* gpu_executable = new GpuExecutable( |   auto* gpu_executable = new GpuExecutable( | ||||||
|       backend_result.first, backend_result.second, gpu_version, |       backend_result.first, backend_result.second, gpu_version, | ||||||
|       std::move(thunk_schedule), std::move(module), |       std::move(thunk_schedule), std::move(module), | ||||||
| @ -709,7 +714,8 @@ StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr( | |||||||
|   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( |   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl( | ||||||
|       hlo_module, llvm_context, target_triple, data_layout, platform_name, |       hlo_module, llvm_context, target_triple, data_layout, platform_name, | ||||||
|       gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction, |       gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction, | ||||||
|       pointer_size, &llvm_module, &buffer_assignment, &thunk_schedule)); |       pointer_size, /*profile_index_map=*/nullptr, &llvm_module, | ||||||
|  |       &buffer_assignment, &thunk_schedule)); | ||||||
|   return llvm_module; |   return llvm_module; | ||||||
| } | } | ||||||
| }  // namespace gpu
 | }  // namespace gpu
 | ||||||
|  | |||||||
| @ -23,7 +23,6 @@ limitations under the License. | |||||||
| #include "absl/memory/memory.h" | #include "absl/memory/memory.h" | ||||||
| #include "tensorflow/compiler/xla/service/hlo_computation.h" | #include "tensorflow/compiler/xla/service/hlo_computation.h" | ||||||
| #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" | #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" | ||||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |  | ||||||
| #include "tensorflow/compiler/xla/service/stream_pool.h" | #include "tensorflow/compiler/xla/service/stream_pool.h" | ||||||
| #include "tensorflow/core/platform/logging.h" | #include "tensorflow/core/platform/logging.h" | ||||||
| #include "tensorflow/core/platform/stream_executor_no_cuda.h" | #include "tensorflow/core/platform/stream_executor_no_cuda.h" | ||||||
| @ -97,26 +96,24 @@ void HloExecutionProfiler::StartHloInstruction() { | |||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void HloExecutionProfiler::FinishHloInstruction( | void HloExecutionProfiler::FinishHloInstruction(size_t index) { | ||||||
|     const HloInstruction* hlo_instruction) { |  | ||||||
|   if (do_profile_) { |   if (do_profile_) { | ||||||
|     hlo_instructions_.erase(hlo_instruction); |     indices_.erase(index); | ||||||
|     profile_->SetCyclesTakenBy( |     profile_->SetCyclesTakenBy(index, GetCyclesTaken(&timers_, sub_streams_, | ||||||
|         hlo_instruction, |                                                      stream_, clock_rate_ghz_)); | ||||||
|         GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); |  | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<ScopedInstructionProfiler> | std::unique_ptr<ScopedInstructionProfiler> | ||||||
| HloExecutionProfiler::MakeScopedInstructionProfiler( | HloExecutionProfiler::MakeScopedInstructionProfiler( | ||||||
|     const HloInstruction* hlo_instruction) { |     absl::optional<int64> index) { | ||||||
|   if (do_profile_ && hlo_instruction != nullptr) { |   if (do_profile_ && index.has_value()) { | ||||||
|     // Make sure that we are not already measuring the time for the same
 |     // Make sure that we are not already measuring the time for the same
 | ||||||
|     // 'hlo_instruction'.
 |     // instruction.
 | ||||||
|     CHECK(hlo_instructions_.insert(hlo_instruction).second) |     // TODO(timshen): provide more useful printout.
 | ||||||
|         << hlo_instruction->name(); |     CHECK(indices_.insert(*index).second) << *index; | ||||||
|   } |   } | ||||||
|   return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction); |   return absl::make_unique<ScopedInstructionProfiler>(this, index); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace gpu
 | }  // namespace gpu
 | ||||||
|  | |||||||
| @ -23,7 +23,6 @@ limitations under the License. | |||||||
| 
 | 
 | ||||||
| #include "tensorflow/compiler/xla/service/hlo_computation.h" | #include "tensorflow/compiler/xla/service/hlo_computation.h" | ||||||
| #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" | #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" | ||||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |  | ||||||
| #include "tensorflow/compiler/xla/service/stream_pool.h" | #include "tensorflow/compiler/xla/service/stream_pool.h" | ||||||
| #include "tensorflow/core/platform/stream_executor_no_cuda.h" | #include "tensorflow/core/platform/stream_executor_no_cuda.h" | ||||||
| 
 | 
 | ||||||
| @ -58,14 +57,17 @@ class HloExecutionProfiler { | |||||||
|   void StartHloInstruction(); |   void StartHloInstruction(); | ||||||
| 
 | 
 | ||||||
|   // If profiling is enabled, stops the per-operation timer and records the time
 |   // If profiling is enabled, stops the per-operation timer and records the time
 | ||||||
|   // that the hlo_instruction took to execute in the profile.
 |   // at `profile_index`. Profile indices can be looked up from
 | ||||||
|   void FinishHloInstruction(const HloInstruction* hlo_instruction); |   // HloProfileIndexMap.
 | ||||||
|  |   void FinishHloInstruction(size_t profile_index); | ||||||
| 
 | 
 | ||||||
|   // Returns a ScopedInstructionProfiler and triggers a call to
 |   // Returns a ScopedInstructionProfiler and triggers a call to
 | ||||||
|   // StartHloInstruction(). Once the returned ScopedInstructionProfiler goes
 |   // StartHloInstruction(). Once the returned ScopedInstructionProfiler goes
 | ||||||
|   // out of scope, it triggers a call to FinishHloInstruction().
 |   // out of scope, it triggers a call to FinishHloInstruction().
 | ||||||
|  |   //
 | ||||||
|  |   // If profile_index is absl::nullopt, it results in a no-op.
 | ||||||
|   std::unique_ptr<ScopedInstructionProfiler> MakeScopedInstructionProfiler( |   std::unique_ptr<ScopedInstructionProfiler> MakeScopedInstructionProfiler( | ||||||
|       const HloInstruction* hlo_instruction); |       absl::optional<int64> profile_index); | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   const bool do_profile_; |   const bool do_profile_; | ||||||
| @ -77,7 +79,7 @@ class HloExecutionProfiler { | |||||||
|   std::stack<std::unique_ptr<se::Timer>> timers_; |   std::stack<std::unique_ptr<se::Timer>> timers_; | ||||||
|   // Contains the HLO instructions for which we are currently measuring the
 |   // Contains the profile indices of the HLO instructions for which we are
 | ||||||
|   // time.
 |   // currently measuring the time.
 | ||||||
|   std::unordered_set<const HloInstruction*> hlo_instructions_; |   std::unordered_set<size_t> indices_; | ||||||
|   bool finished_execution_ = false; |   bool finished_execution_ = false; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| @ -87,21 +89,21 @@ class HloExecutionProfiler { | |||||||
| class ScopedInstructionProfiler { | class ScopedInstructionProfiler { | ||||||
|  public: |  public: | ||||||
|   ScopedInstructionProfiler(HloExecutionProfiler* profiler, |   ScopedInstructionProfiler(HloExecutionProfiler* profiler, | ||||||
|                             const HloInstruction* hlo_instruction) |                             absl::optional<int64> index) | ||||||
|       : profiler_(profiler), hlo_instruction_(hlo_instruction) { |       : profiler_(profiler), index_(index) { | ||||||
|     if (hlo_instruction != nullptr) { |     if (index_.has_value()) { | ||||||
|       profiler->StartHloInstruction(); |       profiler->StartHloInstruction(); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   ~ScopedInstructionProfiler() { |   ~ScopedInstructionProfiler() { | ||||||
|     if (hlo_instruction_ != nullptr) { |     if (index_.has_value()) { | ||||||
|       profiler_->FinishHloInstruction(hlo_instruction_); |       profiler_->FinishHloInstruction(*index_); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   HloExecutionProfiler* profiler_; |   HloExecutionProfiler* profiler_; | ||||||
|   const HloInstruction* hlo_instruction_; |   absl::optional<int64> index_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace gpu
 | }  // namespace gpu
 | ||||||
|  | |||||||
| @ -23,9 +23,9 @@ namespace xla { | |||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| InfeedThunk::InfeedThunk( | InfeedThunk::InfeedThunk( | ||||||
|     const ShapeTree<BufferAllocation::Slice>& infeed_slices, |     ThunkInfo thunk_info, | ||||||
|     const HloInstruction* hlo_instruction) |     const ShapeTree<BufferAllocation::Slice>& infeed_slices) | ||||||
|     : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} |     : Thunk(Kind::kInfeed, thunk_info), infeed_slices_(infeed_slices) {} | ||||||
| 
 | 
 | ||||||
| Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { | Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { | ||||||
|   auto& stream = *params.stream; |   auto& stream = *params.stream; | ||||||
| @ -34,7 +34,7 @@ Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); |   VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   ShapeTree<InfeedBuffer> infeed_buffers = |   ShapeTree<InfeedBuffer> infeed_buffers = | ||||||
|       GetOrCreateInfeedManager()->BlockingGetNextDestination(); |       GetOrCreateInfeedManager()->BlockingGetNextDestination(); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -34,8 +34,8 @@ class InfeedThunk : public Thunk { | |||||||
|  public: |  public: | ||||||
|   // Constructs an InfeedThunk that copies data from the on-device
 |   // Constructs an InfeedThunk that copies data from the on-device
 | ||||||
|   // infeed queue into the buffers in the given shape tree.
 |   // infeed queue into the buffers in the given shape tree.
 | ||||||
|   InfeedThunk(const ShapeTree<BufferAllocation::Slice>& infeed_slices, |   InfeedThunk(ThunkInfo thunk_info, | ||||||
|               const HloInstruction* hlo_instruction); |               const ShapeTree<BufferAllocation::Slice>& infeed_slices); | ||||||
| 
 | 
 | ||||||
|   InfeedThunk(const InfeedThunk&) = delete; |   InfeedThunk(const InfeedThunk&) = delete; | ||||||
|   InfeedThunk& operator=(const InfeedThunk&) = delete; |   InfeedThunk& operator=(const InfeedThunk&) = delete; | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ limitations under the License. | |||||||
| #include "llvm/IR/Module.h" | #include "llvm/IR/Module.h" | ||||||
| #include "tensorflow/compiler/xla/service/buffer_assignment.h" | #include "tensorflow/compiler/xla/service/buffer_assignment.h" | ||||||
| #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" | #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" | ||||||
|  | #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" | ||||||
| #include "tensorflow/compiler/xla/service/name_uniquer.h" | #include "tensorflow/compiler/xla/service/name_uniquer.h" | ||||||
| 
 | 
 | ||||||
| namespace xla { | namespace xla { | ||||||
| @ -33,12 +34,13 @@ class IrEmitterContext { | |||||||
|       const HloModule* hlo_module, const BufferAssignment* buffer_assignment, |       const HloModule* hlo_module, const BufferAssignment* buffer_assignment, | ||||||
|       std::string platform_name, GpuDeviceInfo gpu_device_info, |       std::string platform_name, GpuDeviceInfo gpu_device_info, | ||||||
|       absl::optional<CudaComputeCapability> cuda_compute_capability, |       absl::optional<CudaComputeCapability> cuda_compute_capability, | ||||||
|       llvm::Module* llvm_module) |       const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module) | ||||||
|       : hlo_module_(hlo_module), |       : hlo_module_(hlo_module), | ||||||
|         buffer_assignment_(buffer_assignment), |         buffer_assignment_(buffer_assignment), | ||||||
|         platform_name_(std::move(platform_name)), |         platform_name_(std::move(platform_name)), | ||||||
|         gpu_device_info_(gpu_device_info), |         gpu_device_info_(gpu_device_info), | ||||||
|         cuda_compute_capability_(cuda_compute_capability), |         cuda_compute_capability_(cuda_compute_capability), | ||||||
|  |         profile_index_map_(profile_index_map), | ||||||
|         llvm_module_(llvm_module) {} |         llvm_module_(llvm_module) {} | ||||||
|   // Disallow copy and assign.
 |   // Disallow copy and assign.
 | ||||||
|   IrEmitterContext(const IrEmitterContext&) = delete; |   IrEmitterContext(const IrEmitterContext&) = delete; | ||||||
| @ -54,6 +56,7 @@ class IrEmitterContext { | |||||||
|   absl::optional<CudaComputeCapability> cuda_compute_capability() const { |   absl::optional<CudaComputeCapability> cuda_compute_capability() const { | ||||||
|     return cuda_compute_capability_; |     return cuda_compute_capability_; | ||||||
|   } |   } | ||||||
|  |   const HloProfileIndexMap* profile_index_map() { return profile_index_map_; } | ||||||
|   llvm::Module* llvm_module() { return llvm_module_; } |   llvm::Module* llvm_module() { return llvm_module_; } | ||||||
|   NameUniquer* name_uniquer() { return &name_uniquer_; } |   NameUniquer* name_uniquer() { return &name_uniquer_; } | ||||||
| 
 | 
 | ||||||
| @ -63,6 +66,7 @@ class IrEmitterContext { | |||||||
|   std::string platform_name_; |   std::string platform_name_; | ||||||
|   GpuDeviceInfo gpu_device_info_; |   GpuDeviceInfo gpu_device_info_; | ||||||
|   absl::optional<CudaComputeCapability> cuda_compute_capability_; |   absl::optional<CudaComputeCapability> cuda_compute_capability_; | ||||||
|  |   const HloProfileIndexMap* profile_index_map_; | ||||||
|   llvm::Module* llvm_module_; |   llvm::Module* llvm_module_; | ||||||
|   NameUniquer name_uniquer_; |   NameUniquer name_uniquer_; | ||||||
| }; | }; | ||||||
|  | |||||||
| @ -652,8 +652,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { | |||||||
|               /*updates_gen=*/ |               /*updates_gen=*/ | ||||||
|               scatter_fused_emitter.GetGenerator(root->operand(2)))); |               scatter_fused_emitter.GetGenerator(root->operand(2)))); | ||||||
|         } |         } | ||||||
|         AddThunkToThunkSequence( |         AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( | ||||||
|             absl::make_unique<SequentialThunk>(std::move(thunks), fusion)); |             GetThunkInfo(fusion), std::move(thunks))); | ||||||
|         return Status::OK(); |         return Status::OK(); | ||||||
|       } |       } | ||||||
|       // In the case of root tuple, it can be either reduce or slice input
 |       // In the case of root tuple, it can be either reduce or slice input
 | ||||||
| @ -739,10 +739,11 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { | |||||||
|     auto destination_buffer = GetAllocationSlice(*copy); |     auto destination_buffer = GetAllocationSlice(*copy); | ||||||
|     if (operand_buffer != destination_buffer) { |     if (operand_buffer != destination_buffer) { | ||||||
|       AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>( |       AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>( | ||||||
|  |           GetThunkInfo(copy), | ||||||
|           /*source_address=*/operand_buffer, |           /*source_address=*/operand_buffer, | ||||||
|           /*destination_buffer=*/destination_buffer, |           /*destination_buffer=*/destination_buffer, | ||||||
|           /*mem_size=*/ |           /*mem_size=*/ | ||||||
|           ByteSizeOf(copy->operand(0)->shape()), copy)); |           ByteSizeOf(copy->operand(0)->shape()))); | ||||||
|     } |     } | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
| @ -816,7 +817,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { | |||||||
|       tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); |       tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); | ||||||
|     } |     } | ||||||
|     AddThunkToThunkSequence(absl::make_unique<TupleThunk>( |     AddThunkToThunkSequence(absl::make_unique<TupleThunk>( | ||||||
|         tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); |         GetThunkInfo(tuple), tuple_element_buffers, | ||||||
|  |         GetAllocationSlice(*tuple))); | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
|   AddThunkToThunkSequence( |   AddThunkToThunkSequence( | ||||||
| @ -848,7 +850,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( | |||||||
|   thunks.push_back(BuildKernelThunk(select_and_scatter, |   thunks.push_back(BuildKernelThunk(select_and_scatter, | ||||||
|                                     /*implements_whole_instruction=*/false)); |                                     /*implements_whole_instruction=*/false)); | ||||||
|   std::unique_ptr<SequentialThunk> select_and_scatter_thunk = |   std::unique_ptr<SequentialThunk> select_and_scatter_thunk = | ||||||
|       absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter); |       absl::make_unique<SequentialThunk>(GetThunkInfo(select_and_scatter), | ||||||
|  |                                          std::move(thunks)); | ||||||
| 
 | 
 | ||||||
|   // TODO(b/31410564): Implement dilation rate for select-and-scatter.
 |   // TODO(b/31410564): Implement dilation rate for select-and-scatter.
 | ||||||
|   if (window_util::HasDilation(window)) { |   if (window_util::HasDilation(window)) { | ||||||
| @ -1082,10 +1085,10 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { | |||||||
|   auto destination_buffer = GetAllocationSlice(*scatter); |   auto destination_buffer = GetAllocationSlice(*scatter); | ||||||
|   if (operand_buffer != destination_buffer) { |   if (operand_buffer != destination_buffer) { | ||||||
|     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( | ||||||
|  |         Thunk::ThunkInfo(), | ||||||
|         /*source_address=*/operand_buffer, |         /*source_address=*/operand_buffer, | ||||||
|         /*destination_buffer=*/destination_buffer, |         /*destination_buffer=*/destination_buffer, | ||||||
|         /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), |         /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()))); | ||||||
|         /*hlo_instruction=*/nullptr)); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   thunks.push_back( |   thunks.push_back( | ||||||
| @ -1109,8 +1112,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { | |||||||
|   if (thunks.size() == 1) { |   if (thunks.size() == 1) { | ||||||
|     AddThunkToThunkSequence(std::move(thunks[0])); |     AddThunkToThunkSequence(std::move(thunks[0])); | ||||||
|   } else { |   } else { | ||||||
|     AddThunkToThunkSequence( |     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( | ||||||
|         absl::make_unique<SequentialThunk>(std::move(thunks), scatter)); |         GetThunkInfo(scatter), std::move(thunks))); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| @ -1282,10 +1285,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { | |||||||
|       // TODO(b/26783907): Figure out why we never seem to share buffers for
 |       // TODO(b/26783907): Figure out why we never seem to share buffers for
 | ||||||
|       // key/value sort.
 |       // key/value sort.
 | ||||||
|       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( | ||||||
|  |           Thunk::ThunkInfo(), | ||||||
|           /*source_address=*/source_address, |           /*source_address=*/source_address, | ||||||
|           /*destination_buffer=*/destination_buffer, |           /*destination_buffer=*/destination_buffer, | ||||||
|           /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()), |           /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()))); | ||||||
|           nullptr)); |  | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -1419,8 +1422,8 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { | |||||||
|     TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); |     TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   AddThunkToThunkSequence( |   AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( | ||||||
|       absl::make_unique<SequentialThunk>(std::move(thunks), sort)); |       GetThunkInfo(sort), std::move(thunks))); | ||||||
|   if (sort->operand_count() > 1) { |   if (sort->operand_count() > 1) { | ||||||
|     // Emit the tuple as part of the last stage of sorting.
 |     // Emit the tuple as part of the last stage of sorting.
 | ||||||
|     // We are currently in the block sorted.in_bounds.after.
 |     // We are currently in the block sorted.in_bounds.after.
 | ||||||
| @ -1438,14 +1441,15 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) { | Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) { | ||||||
|   AddThunkToThunkSequence( |   AddThunkToThunkSequence(absl::make_unique<ReplicaIdThunk>( | ||||||
|       absl::make_unique<ReplicaIdThunk>(GetAllocationSlice(*hlo), hlo)); |       GetThunkInfo(hlo), GetAllocationSlice(*hlo))); | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) { | Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) { | ||||||
|   AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>( |   AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>( | ||||||
|       GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo), hlo)); |       GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), | ||||||
|  |       GetAllocationSlice(*hlo))); | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -1478,15 +1482,16 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { | |||||||
|       tuple_element_buffers.push_back(buffers[i].destination_buffer); |       tuple_element_buffers.push_back(buffers[i].destination_buffer); | ||||||
|     } |     } | ||||||
|     auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>( |     auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>( | ||||||
|  |         GetThunkInfo(crs), | ||||||
|         /*replica_count=*/hlo_module_config_.replica_count(), |         /*replica_count=*/hlo_module_config_.replica_count(), | ||||||
|         /*buffers=*/std::move(buffers), crs); |         /*buffers=*/std::move(buffers)); | ||||||
|     if (crs->shape().IsTuple()) { |     if (crs->shape().IsTuple()) { | ||||||
|       std::vector<std::unique_ptr<Thunk>> thunks; |       std::vector<std::unique_ptr<Thunk>> thunks; | ||||||
|       thunks.push_back(std::move(all_reduce_thunk)); |       thunks.push_back(std::move(all_reduce_thunk)); | ||||||
|       thunks.push_back(absl::make_unique<TupleThunk>( |       thunks.push_back(absl::make_unique<TupleThunk>( | ||||||
|           tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); |           Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs))); | ||||||
|       AddThunkToThunkSequence( |       AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( | ||||||
|           absl::make_unique<SequentialThunk>(std::move(thunks), crs)); |           GetThunkInfo(crs), std::move(thunks))); | ||||||
|     } else { |     } else { | ||||||
|       AddThunkToThunkSequence(std::move(all_reduce_thunk)); |       AddThunkToThunkSequence(std::move(all_reduce_thunk)); | ||||||
|     } |     } | ||||||
| @ -1520,9 +1525,10 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { | |||||||
|     CHECK(crs->operand(0)->shape().IsArray()) |     CHECK(crs->operand(0)->shape().IsArray()) | ||||||
|         << "Operands to all-reduce must be arrays: " << crs->ToString(); |         << "Operands to all-reduce must be arrays: " << crs->ToString(); | ||||||
|     AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>( |     AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>( | ||||||
|  |         GetThunkInfo(crs), | ||||||
|         /*source_address=*/GetAllocationSlice(*crs->operand(0)), |         /*source_address=*/GetAllocationSlice(*crs->operand(0)), | ||||||
|         /*destination_buffer=*/GetAllocationSlice(*crs), |         /*destination_buffer=*/GetAllocationSlice(*crs), | ||||||
|         /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); |         /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()))); | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -1535,16 +1541,17 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { | |||||||
|                                         .GetUniqueSlice(crs, {i}) |                                         .GetUniqueSlice(crs, {i}) | ||||||
|                                         .ValueOrDie()); |                                         .ValueOrDie()); | ||||||
|     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( | ||||||
|  |         Thunk::ThunkInfo(), | ||||||
|         /*source_address=*/GetAllocationSlice(*crs->operand(i)), |         /*source_address=*/GetAllocationSlice(*crs->operand(i)), | ||||||
|         /*destination_buffer=*/tuple_element_buffers.back(), |         /*destination_buffer=*/tuple_element_buffers.back(), | ||||||
|         /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); |         /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()))); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   // Output a tuple of the buffers above.
 |   // Output a tuple of the buffers above.
 | ||||||
|   thunks.push_back(absl::make_unique<TupleThunk>( |   thunks.push_back(absl::make_unique<TupleThunk>( | ||||||
|       tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); |       Thunk::ThunkInfo(), tuple_element_buffers, GetAllocationSlice(*crs))); | ||||||
|   AddThunkToThunkSequence( |   AddThunkToThunkSequence( | ||||||
|       absl::make_unique<SequentialThunk>(std::move(thunks), crs)); |       absl::make_unique<SequentialThunk>(GetThunkInfo(crs), std::move(thunks))); | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -1787,8 +1794,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   return absl::make_unique<KernelThunk>( |   return absl::make_unique<KernelThunk>( | ||||||
|       non_constant_buffers, std::string(kernel->getName()), |       implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(), | ||||||
|       implements_whole_instruction ? inst : nullptr); |       non_constant_buffers, std::string(kernel->getName())); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( | StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( | ||||||
| @ -1838,8 +1845,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( | |||||||
|     absl::Span<const uint8> literal_bytes( |     absl::Span<const uint8> literal_bytes( | ||||||
|         reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes); |         reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes); | ||||||
|     if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { |     if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { | ||||||
|       return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index), |       return {absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), | ||||||
|                                               nullptr)}; |                                               GetAllocationSlice(*hlo, index))}; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
 |     // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
 | ||||||
| @ -1857,7 +1864,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( | |||||||
|       } |       } | ||||||
|       uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); |       uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); | ||||||
|       return {absl::make_unique<Memset32BitValueThunk>( |       return {absl::make_unique<Memset32BitValueThunk>( | ||||||
|           pattern32, GetAllocationSlice(*hlo, index), nullptr)}; |           Thunk::ThunkInfo(), pattern32, GetAllocationSlice(*hlo, index))}; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
 |     // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
 | ||||||
| @ -1868,7 +1875,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( | |||||||
|       uint32 word; |       uint32 word; | ||||||
|       memcpy(&word, literal_bytes.data(), sizeof(word)); |       memcpy(&word, literal_bytes.data(), sizeof(word)); | ||||||
|       return {absl::make_unique<Memset32BitValueThunk>( |       return {absl::make_unique<Memset32BitValueThunk>( | ||||||
|           word, GetAllocationSlice(*hlo, index), nullptr)}; |           Thunk::ThunkInfo(), word, GetAllocationSlice(*hlo, index))}; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -2014,9 +2021,10 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk( | |||||||
|   TF_CHECK_OK(body->Accept(&ir_emitter_body)); |   TF_CHECK_OK(body->Accept(&ir_emitter_body)); | ||||||
| 
 | 
 | ||||||
|   return absl::make_unique<WhileThunk>( |   return absl::make_unique<WhileThunk>( | ||||||
|  |       GetThunkInfo(hlo), | ||||||
|       GetAllocationSlice(*condition->root_instruction()),  // cond result
 |       GetAllocationSlice(*condition->root_instruction()),  // cond result
 | ||||||
|       ir_emitter_condition.ConsumeThunkSequence(), |       ir_emitter_condition.ConsumeThunkSequence(), | ||||||
|       ir_emitter_body.ConsumeThunkSequence(), hlo); |       ir_emitter_body.ConsumeThunkSequence()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk( | std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk( | ||||||
| @ -2031,8 +2039,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk( | |||||||
|                                     ir_emitter_context_); |                                     ir_emitter_context_); | ||||||
|   TF_CHECK_OK(body->Accept(&ir_emitter_body)); |   TF_CHECK_OK(body->Accept(&ir_emitter_body)); | ||||||
| 
 | 
 | ||||||
|   return absl::make_unique<ForThunk>( |   return absl::make_unique<ForThunk>(GetThunkInfo(hlo), loop_limit, | ||||||
|       loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); |                                      ir_emitter_body.ConsumeThunkSequence()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk( | std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk( | ||||||
| @ -2054,8 +2062,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk( | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   return absl::make_unique<ConditionalThunk>( |   return absl::make_unique<ConditionalThunk>( | ||||||
|       GetAllocationSlice(*hlo->operand(0)), branch_operands, |       GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands, | ||||||
|       std::move(branch_thunks), hlo); |       std::move(branch_thunks)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Status IrEmitterUnnested::EmitTargetElementLoopInThunk( | Status IrEmitterUnnested::EmitTargetElementLoopInThunk( | ||||||
| @ -3589,8 +3597,8 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( | |||||||
|                          ir_emitter_context_->llvm_module()); |                          ir_emitter_context_->llvm_module()); | ||||||
| 
 | 
 | ||||||
|   thunks.push_back(std::move(kernel_thunk)); |   thunks.push_back(std::move(kernel_thunk)); | ||||||
|   auto sequential_thunk = |   auto sequential_thunk = absl::make_unique<SequentialThunk>( | ||||||
|       absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo); |       GetThunkInfo(unnested_hlo), std::move(thunks)); | ||||||
|   AddThunkToThunkSequence(std::move(sequential_thunk)); |   AddThunkToThunkSequence(std::move(sequential_thunk)); | ||||||
| 
 | 
 | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| @ -3757,5 +3765,15 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( | |||||||
|   return emit_status; |   return emit_status; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo( | ||||||
|  |     const HloInstruction* hlo) const { | ||||||
|  |   auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo); | ||||||
|  |   if (const auto* index_map = ir_emitter_context_->profile_index_map()) { | ||||||
|  |     info.profile_index.emplace( | ||||||
|  |         static_cast<int64>(index_map->GetProfileIndexFor(*hlo))); | ||||||
|  |   } | ||||||
|  |   return info; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace gpu
 | }  // namespace gpu
 | ||||||
| }  // namespace xla
 | }  // namespace xla
 | ||||||
|  | |||||||
| @ -548,6 +548,8 @@ class IrEmitterUnnested : public IrEmitter, | |||||||
|   // Returns the last generated thunk.
 |   // Returns the last generated thunk.
 | ||||||
|   Thunk* LastThunk() const { return thunk_sequence_.back().get(); } |   Thunk* LastThunk() const { return thunk_sequence_.back().get(); } | ||||||
| 
 | 
 | ||||||
|  |   Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override; | ||||||
|  | 
 | ||||||
|   // The thunk sequence this IrEmitter generates for the input computation.
 |   // The thunk sequence this IrEmitter generates for the input computation.
 | ||||||
|   ThunkSequence thunk_sequence_; |   ThunkSequence thunk_sequence_; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -33,10 +33,10 @@ limitations under the License. | |||||||
| namespace xla { | namespace xla { | ||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args, | KernelThunk::KernelThunk(ThunkInfo thunk_info, | ||||||
|                          const string& kernel_name, |                          absl::Span<const BufferAllocation* const> args, | ||||||
|                          const HloInstruction* hlo_instruction) |                          const string& kernel_name) | ||||||
|     : Thunk(Kind::kKernel, hlo_instruction), |     : Thunk(Kind::kKernel, thunk_info), | ||||||
|       args_(args.begin(), args.end()), |       args_(args.begin(), args.end()), | ||||||
|       kernel_name_(kernel_name) {} |       kernel_name_(kernel_name) {} | ||||||
| 
 | 
 | ||||||
| @ -114,7 +114,7 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   return ExecuteKernelOnStream(*kernel, buffer_args, |   return ExecuteKernelOnStream(*kernel, buffer_args, | ||||||
|                                launch_dimensions.threads_per_block(), |                                launch_dimensions.threads_per_block(), | ||||||
|                                launch_dimensions.block_count(), params.stream); |                                launch_dimensions.block_count(), params.stream); | ||||||
|  | |||||||
| @ -47,8 +47,9 @@ class KernelThunk : public Thunk { | |||||||
|   // Constructs a thunk for the given kernel.
 |   // Constructs a thunk for the given kernel.
 | ||||||
|   //
 |   //
 | ||||||
|   // `hlo_instruction` is as in Thunk. Other arguments are as the class members.
 |   // `hlo_instruction` is as in Thunk. Other arguments are as the class members.
 | ||||||
|   KernelThunk(absl::Span<const BufferAllocation* const> args, |   KernelThunk(ThunkInfo thunk_info, | ||||||
|               const string& kernel_name, const HloInstruction* hlo_instruction); |               absl::Span<const BufferAllocation* const> args, | ||||||
|  |               const string& kernel_name); | ||||||
|   KernelThunk(const KernelThunk&) = delete; |   KernelThunk(const KernelThunk&) = delete; | ||||||
|   KernelThunk& operator=(const KernelThunk&) = delete; |   KernelThunk& operator=(const KernelThunk&) = delete; | ||||||
|   ~KernelThunk() override = default; |   ~KernelThunk() override = default; | ||||||
|  | |||||||
| @ -25,7 +25,7 @@ Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   se::DeviceMemoryBase dest_data = |   se::DeviceMemoryBase dest_data = | ||||||
|       params.buffer_allocations->GetDeviceAddress(dest_); |       params.buffer_allocations->GetDeviceAddress(dest_); | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   params.stream->ThenMemZero(&dest_data, dest_data.size()); |   params.stream->ThenMemZero(&dest_data, dest_data.size()); | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| @ -34,7 +34,7 @@ Status Memset32BitValueThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   se::DeviceMemoryBase dest_data = |   se::DeviceMemoryBase dest_data = | ||||||
|       params.buffer_allocations->GetDeviceAddress(dest_); |       params.buffer_allocations->GetDeviceAddress(dest_); | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   params.stream->ThenMemset32(&dest_data, value_, dest_data.size()); |   params.stream->ThenMemset32(&dest_data, value_, dest_data.size()); | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
|  | |||||||
| @ -32,9 +32,9 @@ namespace gpu { | |||||||
| // Thunk that zeroes out a given chunk of memory.
 | // Thunk that zeroes out a given chunk of memory.
 | ||||||
| class MemzeroThunk : public Thunk { | class MemzeroThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   explicit MemzeroThunk(const BufferAllocation::Slice& dest, |   explicit MemzeroThunk(ThunkInfo thunk_info, | ||||||
|                         const HloInstruction* hlo) |                         const BufferAllocation::Slice& dest) | ||||||
|       : Thunk(Kind::kMemzero, hlo), dest_(dest) {} |       : Thunk(Kind::kMemzero, thunk_info), dest_(dest) {} | ||||||
| 
 | 
 | ||||||
|   Status ExecuteOnStream(const ExecuteParams& params) override; |   Status ExecuteOnStream(const ExecuteParams& params) override; | ||||||
| 
 | 
 | ||||||
| @ -46,10 +46,11 @@ class MemzeroThunk : public Thunk { | |||||||
| // destination chunk must have size divisible by 32 bits.
 | // destination chunk must have size divisible by 32 bits.
 | ||||||
| class Memset32BitValueThunk : public Thunk { | class Memset32BitValueThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   explicit Memset32BitValueThunk(uint32 value, |   explicit Memset32BitValueThunk(ThunkInfo thunk_info, uint32 value, | ||||||
|                                  const BufferAllocation::Slice& dest, |                                  const BufferAllocation::Slice& dest) | ||||||
|                                  const HloInstruction* hlo) |       : Thunk(Kind::kMemset32BitValue, thunk_info), | ||||||
|       : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} |         value_(value), | ||||||
|  |         dest_(dest) {} | ||||||
| 
 | 
 | ||||||
|   Status ExecuteOnStream(const ExecuteParams& params) override; |   Status ExecuteOnStream(const ExecuteParams& params) override; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -541,9 +541,9 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| NcclAllReduceThunk::NcclAllReduceThunk( | NcclAllReduceThunk::NcclAllReduceThunk( | ||||||
|     int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers, |     ThunkInfo thunk_info, int64 replica_count, | ||||||
|     const HloInstruction* all_reduce) |     std::vector<NcclAllReduceThunk::Buffer> buffers) | ||||||
|     : Thunk(Thunk::kNcclAllReduce, all_reduce), |     : Thunk(Thunk::kNcclAllReduce, thunk_info), | ||||||
|       replica_count_(replica_count), |       replica_count_(replica_count), | ||||||
|       buffers_(std::move(buffers)), |       buffers_(std::move(buffers)), | ||||||
|       aux_data_(absl::make_unique<AuxData>()) { |       aux_data_(absl::make_unique<AuxData>()) { | ||||||
| @ -555,7 +555,7 @@ NcclAllReduceThunk::NcclAllReduceThunk( | |||||||
| Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { | Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { | ||||||
|   VLOG(1) << "Starting NcclAllReduceThunk."; |   VLOG(1) << "Starting NcclAllReduceThunk."; | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
| 
 | 
 | ||||||
|   auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction()); |   auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction()); | ||||||
|   int64 local_device_ordinal = params.stream->parent()->device_ordinal(); |   int64 local_device_ordinal = params.stream->parent()->device_ordinal(); | ||||||
|  | |||||||
| @ -56,8 +56,8 @@ class NcclAllReduceThunk : public Thunk { | |||||||
|     BufferAllocation::Slice source_buffer; |     BufferAllocation::Slice source_buffer; | ||||||
|     BufferAllocation::Slice destination_buffer; |     BufferAllocation::Slice destination_buffer; | ||||||
|   }; |   }; | ||||||
|   NcclAllReduceThunk(int64 replica_count, std::vector<Buffer> buffers, |   NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count, | ||||||
|                      const HloInstruction* all_reduce); |                      std::vector<Buffer> buffers); | ||||||
|   ~NcclAllReduceThunk() override; |   ~NcclAllReduceThunk() override; | ||||||
| 
 | 
 | ||||||
|   Status ExecuteOnStream(const ExecuteParams& params) override; |   Status ExecuteOnStream(const ExecuteParams& params) override; | ||||||
|  | |||||||
| @ -23,9 +23,9 @@ limitations under the License. | |||||||
| namespace xla { | namespace xla { | ||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices, | OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, | ||||||
|                            const HloInstruction* hlo_instruction) |                            ShapeTree<BufferAllocation::Slice> outfeed_slices) | ||||||
|     : Thunk(Kind::kOutfeed, hlo_instruction), |     : Thunk(Kind::kOutfeed, thunk_info), | ||||||
|       outfeed_slices_(std::move(outfeed_slices)) {} |       outfeed_slices_(std::move(outfeed_slices)) {} | ||||||
| 
 | 
 | ||||||
| Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { | Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { | ||||||
| @ -35,7 +35,7 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString(); |   VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString(); | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(); |   OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(); | ||||||
|   ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers = |   ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers = | ||||||
|       outfeed_manager->BlockingGetNextDestination(); |       outfeed_manager->BlockingGetNextDestination(); | ||||||
|  | |||||||
| @ -32,8 +32,8 @@ class OutfeedThunk : public Thunk { | |||||||
|  public: |  public: | ||||||
|   // Constructs a OutfeedThunk that copies data to the host-side
 |   // Constructs a OutfeedThunk that copies data to the host-side
 | ||||||
|   // outfeed queue from the buffers in the given shape tree.
 |   // outfeed queue from the buffers in the given shape tree.
 | ||||||
|   OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices, |   OutfeedThunk(ThunkInfo thunk_info, | ||||||
|                const HloInstruction* hlo_instruction); |                ShapeTree<BufferAllocation::Slice> outfeed_slices); | ||||||
| 
 | 
 | ||||||
|   OutfeedThunk(const OutfeedThunk&) = delete; |   OutfeedThunk(const OutfeedThunk&) = delete; | ||||||
|   OutfeedThunk& operator=(const OutfeedThunk&) = delete; |   OutfeedThunk& operator=(const OutfeedThunk&) = delete; | ||||||
|  | |||||||
| @ -18,13 +18,13 @@ limitations under the License. | |||||||
| namespace xla { | namespace xla { | ||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| ReplicaIdThunk::ReplicaIdThunk(const BufferAllocation::Slice& dest, | ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info, | ||||||
|                                const HloInstruction* instr) |                                const BufferAllocation::Slice& dest) | ||||||
|     : Thunk(Kind::kReplicaId, instr), dest_(dest) {} |     : Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {} | ||||||
| 
 | 
 | ||||||
| Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) { | Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) { | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
| 
 | 
 | ||||||
|   auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_); |   auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_); | ||||||
|   TF_ASSIGN_OR_RETURN(int replica_id, |   TF_ASSIGN_OR_RETURN(int replica_id, | ||||||
|  | |||||||
| @ -26,8 +26,7 @@ namespace gpu { | |||||||
| // Thunk that implements the ReplicaId HLO.
 | // Thunk that implements the ReplicaId HLO.
 | ||||||
| class ReplicaIdThunk : public Thunk { | class ReplicaIdThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   ReplicaIdThunk(const BufferAllocation::Slice& dest, |   ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest); | ||||||
|                  const HloInstruction* instr); |  | ||||||
| 
 | 
 | ||||||
|   Status ExecuteOnStream(const ExecuteParams& params) override; |   Status ExecuteOnStream(const ExecuteParams& params) override; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -24,9 +24,9 @@ namespace gpu { | |||||||
| 
 | 
 | ||||||
| using ::tensorflow::profiler::ScopedAnnotation; | using ::tensorflow::profiler::ScopedAnnotation; | ||||||
| 
 | 
 | ||||||
| SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks, | SequentialThunk::SequentialThunk(ThunkInfo thunk_info, | ||||||
|                                  const HloInstruction* hlo) |                                  std::vector<std::unique_ptr<Thunk>> thunks) | ||||||
|     : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} |     : Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {} | ||||||
| 
 | 
 | ||||||
| void SequentialThunk::ComputeAnnotations() { | void SequentialThunk::ComputeAnnotations() { | ||||||
|   for (const auto& thunk : thunks_) { |   for (const auto& thunk : thunks_) { | ||||||
| @ -44,7 +44,7 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable, | |||||||
| 
 | 
 | ||||||
| Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { | Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   for (const auto& thunk : thunks_) { |   for (const auto& thunk : thunks_) { | ||||||
|     ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); |     ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); | ||||||
|     TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); |     TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); | ||||||
|  | |||||||
| @ -32,8 +32,8 @@ namespace gpu { | |||||||
| // require multiple kernel launches or library calls.
 | // require multiple kernel launches or library calls.
 | ||||||
| class SequentialThunk : public Thunk { | class SequentialThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks, |   SequentialThunk(ThunkInfo thunk_info, | ||||||
|                   const HloInstruction* hlo); |                   std::vector<std::unique_ptr<Thunk>> thunks); | ||||||
|   SequentialThunk(const SequentialThunk&) = delete; |   SequentialThunk(const SequentialThunk&) = delete; | ||||||
|   SequentialThunk& operator=(const SequentialThunk&) = delete; |   SequentialThunk& operator=(const SequentialThunk&) = delete; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -68,13 +68,21 @@ class Thunk { | |||||||
|     kWhile, |     kWhile, | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|  |   struct ThunkInfo { | ||||||
|  |     const HloInstruction* hlo_instruction = nullptr; | ||||||
|  |     absl::optional<int64> profile_index; | ||||||
|  |     // TODO(timshen): Remove hlo_instruction and add name(),
 | ||||||
|  |     // profile_annotation() here.
 | ||||||
|  |   }; | ||||||
|  | 
 | ||||||
|   // The hlo_instruction argument is meant to be the instruction this thunk was
 |   // The hlo_instruction argument is meant to be the instruction this thunk was
 | ||||||
|   // generated from, but Thunk never uses this argument other than to save it
 |   // generated from, but Thunk never uses this argument other than to save it
 | ||||||
|   // to Thunk::hlo_instruction, so it can be null.
 |   // to Thunk::hlo_instruction, so it can be null.
 | ||||||
|   explicit Thunk(Kind kind, const HloInstruction* hlo_instruction) |   explicit Thunk(Kind kind, ThunkInfo thunk_info) | ||||||
|       : kind_(kind), |       : kind_(kind), | ||||||
|         hlo_instruction_(hlo_instruction), |         hlo_instruction_(thunk_info.hlo_instruction), | ||||||
|         name_(hlo_instruction_ ? hlo_instruction_->name() : "") {} |         name_(hlo_instruction_ ? hlo_instruction_->name() : ""), | ||||||
|  |         profile_index_(thunk_info.profile_index) {} | ||||||
|   virtual ~Thunk() {} |   virtual ~Thunk() {} | ||||||
|   Thunk(const Thunk&) = delete; |   Thunk(const Thunk&) = delete; | ||||||
|   Thunk& operator=(const Thunk&) = delete; |   Thunk& operator=(const Thunk&) = delete; | ||||||
| @ -128,6 +136,8 @@ class Thunk { | |||||||
|  protected: |  protected: | ||||||
|   const HloInstruction* hlo_instruction() const { return hlo_instruction_; } |   const HloInstruction* hlo_instruction() const { return hlo_instruction_; } | ||||||
| 
 | 
 | ||||||
|  |   absl::optional<int64> profile_index() const { return profile_index_; } | ||||||
|  | 
 | ||||||
|   const HloModuleConfig& GetModuleConfig() const { |   const HloModuleConfig& GetModuleConfig() const { | ||||||
|     return hlo_instruction()->GetModule()->config(); |     return hlo_instruction()->GetModule()->config(); | ||||||
|   } |   } | ||||||
| @ -146,8 +156,12 @@ class Thunk { | |||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   Kind kind_; |   Kind kind_; | ||||||
|  | 
 | ||||||
|  |   // Will be removed in the future, as Thunk is migrating away from the
 | ||||||
|  |   // monolithic HloInstruction.
 | ||||||
|   const HloInstruction* hlo_instruction_; |   const HloInstruction* hlo_instruction_; | ||||||
|   std::string name_; |   std::string name_; | ||||||
|  |   absl::optional<int64> profile_index_; | ||||||
|   string profile_annotation_; |   string profile_annotation_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -40,11 +40,11 @@ namespace gpu { | |||||||
| std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) { | std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) { | ||||||
|   const HloInstruction* operand = inst->operand(0); |   const HloInstruction* operand = inst->operand(0); | ||||||
|   return absl::make_unique<FftThunk>( |   return absl::make_unique<FftThunk>( | ||||||
|       inst->fft_type(), inst->fft_length(), |       context_->GetThunkInfo(inst), inst->fft_type(), inst->fft_length(), | ||||||
|       /*input_buffer=*/GetAllocationSlice(*operand), |       /*input_buffer=*/GetAllocationSlice(*operand), | ||||||
|       /*output_buffer=*/GetAllocationSlice(*inst), |       /*output_buffer=*/GetAllocationSlice(*inst), | ||||||
|       /*input_shape=*/operand->shape(), |       /*input_shape=*/operand->shape(), | ||||||
|       /*output_shape=*/inst->shape(), inst); |       /*output_shape=*/inst->shape()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk( | std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk( | ||||||
| @ -63,11 +63,11 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk( | |||||||
|                              : n * n * elem_size; |                              : n * n * elem_size; | ||||||
|   int64 b_batch_stride = m * n * elem_size; |   int64 b_batch_stride = m * n * elem_size; | ||||||
|   return absl::make_unique<TriangularSolveThunk>( |   return absl::make_unique<TriangularSolveThunk>( | ||||||
|       inst->triangular_solve_options(), |       context_->GetThunkInfo(inst), inst->triangular_solve_options(), | ||||||
|       /*a_input_buffer=*/GetAllocationSlice(*a), |       /*a_input_buffer=*/GetAllocationSlice(*a), | ||||||
|       /*b_input_buffer=*/GetAllocationSlice(*inst), |       /*b_input_buffer=*/GetAllocationSlice(*inst), | ||||||
|       inst->shape().element_type(), batch_size, m, n, a_batch_stride, |       inst->shape().element_type(), batch_size, m, n, a_batch_stride, | ||||||
|       b_batch_stride, inst); |       b_batch_stride); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk( | std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk( | ||||||
| @ -86,24 +86,27 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk( | |||||||
|     if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { |     if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { | ||||||
|       std::vector<std::unique_ptr<Thunk>> thunks; |       std::vector<std::unique_ptr<Thunk>> thunks; | ||||||
|       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( | ||||||
|  |           Thunk::ThunkInfo(), | ||||||
|           /*source_buffer=*/GetAllocationSlice(*bias), |           /*source_buffer=*/GetAllocationSlice(*bias), | ||||||
|           /*destination_buffer=*/GetAllocationSlice(*inst), |           /*destination_buffer=*/GetAllocationSlice(*inst), | ||||||
|           /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr)); |           /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()))); | ||||||
|       thunks.push_back(absl::make_unique<GemmThunk>( |       thunks.push_back(absl::make_unique<GemmThunk>( | ||||||
|  |           context_->GetThunkInfo(inst), | ||||||
|           GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
 |           GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
 | ||||||
|           GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
 |           GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
 | ||||||
|           GetAllocationSlice(*inst),  // The output buffer.
 |           GetAllocationSlice(*inst),  // The output buffer.
 | ||||||
|           /*implements_whole_instruction=*/false, inst, |           /*implements_whole_instruction=*/false, std::move(gemm_config))); | ||||||
|           std::move(gemm_config))); |       return absl::make_unique<SequentialThunk>(context_->GetThunkInfo(inst), | ||||||
|       return absl::make_unique<SequentialThunk>(std::move(thunks), inst); |                                                 std::move(thunks)); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   return absl::make_unique<GemmThunk>( |   return absl::make_unique<GemmThunk>( | ||||||
|  |       context_->GetThunkInfo(inst), | ||||||
|       GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
 |       GetAllocationSlice(*lhs),   // The buffer assigned to LHS.
 | ||||||
|       GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
 |       GetAllocationSlice(*rhs),   // The buffer assigned to RHS.
 | ||||||
|       GetAllocationSlice(*inst),  // The output buffer.
 |       GetAllocationSlice(*inst),  // The output buffer.
 | ||||||
|       /*implements_whole_instruction=*/true, inst, std::move(gemm_config)); |       /*implements_whole_instruction=*/true, std::move(gemm_config)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk( | std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk( | ||||||
| @ -115,7 +118,7 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk( | |||||||
|       [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { |       [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { | ||||||
|         *slice = GetAllocationSlice(*inst, index); |         *slice = GetAllocationSlice(*inst, index); | ||||||
|       }); |       }); | ||||||
|   return absl::make_unique<InfeedThunk>(slices, inst); |   return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst), slices); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk( | std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk( | ||||||
| @ -130,7 +133,8 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk( | |||||||
|       *slice = status_or_slice.ValueOrDie(); |       *slice = status_or_slice.ValueOrDie(); | ||||||
|     } |     } | ||||||
|   }); |   }); | ||||||
|   return absl::make_unique<OutfeedThunk>(std::move(slices), inst); |   return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst), | ||||||
|  |                                          std::move(slices)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | ||||||
| @ -152,6 +156,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
| 
 | 
 | ||||||
|     AddThunkToThunkSequence( |     AddThunkToThunkSequence( | ||||||
|         absl::make_unique<CudnnBatchNormForwardInferenceThunk>( |         absl::make_unique<CudnnBatchNormForwardInferenceThunk>( | ||||||
|  |             context_->GetThunkInfo(custom_call), | ||||||
|             /*operand=*/GetAllocationSlice(*custom_call->operand(0)), |             /*operand=*/GetAllocationSlice(*custom_call->operand(0)), | ||||||
|             /*scale=*/GetAllocationSlice(*custom_call->operand(1)), |             /*scale=*/GetAllocationSlice(*custom_call->operand(1)), | ||||||
|             /*offset=*/GetAllocationSlice(*custom_call->operand(2)), |             /*offset=*/GetAllocationSlice(*custom_call->operand(2)), | ||||||
| @ -159,8 +164,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
|             /*variance=*/GetAllocationSlice(*custom_call->operand(4)), |             /*variance=*/GetAllocationSlice(*custom_call->operand(4)), | ||||||
|             /*epsilon=*/epsilon_value, |             /*epsilon=*/epsilon_value, | ||||||
|             /*feature_index=*/feature_index_value, |             /*feature_index=*/feature_index_value, | ||||||
|             /*output=*/GetAllocationSlice(*custom_call), |             /*output=*/GetAllocationSlice(*custom_call))); | ||||||
|             /*hlo=*/custom_call)); |  | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -181,6 +185,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
|     auto output_inv_stddev = GetAllocationSlice(*custom_call, {2}); |     auto output_inv_stddev = GetAllocationSlice(*custom_call, {2}); | ||||||
|     AddThunkToThunkSequence( |     AddThunkToThunkSequence( | ||||||
|         absl::make_unique<CudnnBatchNormForwardTrainingThunk>( |         absl::make_unique<CudnnBatchNormForwardTrainingThunk>( | ||||||
|  |             context_->GetThunkInfo(custom_call), | ||||||
|             /*operand=*/GetAllocationSlice(*custom_call->operand(0)), |             /*operand=*/GetAllocationSlice(*custom_call->operand(0)), | ||||||
|             /*scale=*/GetAllocationSlice(*custom_call->operand(1)), |             /*scale=*/GetAllocationSlice(*custom_call->operand(1)), | ||||||
|             /*offset=*/GetAllocationSlice(*custom_call->operand(2)), |             /*offset=*/GetAllocationSlice(*custom_call->operand(2)), | ||||||
| @ -189,8 +194,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
|             /*output_data=*/output_data, |             /*output_data=*/output_data, | ||||||
|             /*output_mean=*/output_mean, |             /*output_mean=*/output_mean, | ||||||
|             /*output_inv_stddev=*/output_inv_stddev, |             /*output_inv_stddev=*/output_inv_stddev, | ||||||
|             /*output_tuple=*/GetAllocationSlice(*custom_call), |             /*output_tuple=*/GetAllocationSlice(*custom_call))); | ||||||
|             /*hlo=*/custom_call)); |  | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -209,6 +213,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
|     auto output_grad_scale = GetAllocationSlice(*custom_call, {1}); |     auto output_grad_scale = GetAllocationSlice(*custom_call, {1}); | ||||||
|     auto output_grad_offset = GetAllocationSlice(*custom_call, {2}); |     auto output_grad_offset = GetAllocationSlice(*custom_call, {2}); | ||||||
|     AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>( |     AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>( | ||||||
|  |         context_->GetThunkInfo(custom_call), | ||||||
|         /*operand=*/GetAllocationSlice(*custom_call->operand(0)), |         /*operand=*/GetAllocationSlice(*custom_call->operand(0)), | ||||||
|         /*scale=*/GetAllocationSlice(*custom_call->operand(1)), |         /*scale=*/GetAllocationSlice(*custom_call->operand(1)), | ||||||
|         /*mean=*/GetAllocationSlice(*custom_call->operand(2)), |         /*mean=*/GetAllocationSlice(*custom_call->operand(2)), | ||||||
| @ -219,8 +224,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
|         /*output_grad_data=*/output_grad_data, |         /*output_grad_data=*/output_grad_data, | ||||||
|         /*output_grad_scale=*/output_grad_scale, |         /*output_grad_scale=*/output_grad_scale, | ||||||
|         /*output_grad_offset=*/output_grad_offset, |         /*output_grad_offset=*/output_grad_offset, | ||||||
|         /*output_tuple=*/GetAllocationSlice(*custom_call), |         /*output_tuple=*/GetAllocationSlice(*custom_call))); | ||||||
|         /*hlo=*/custom_call)); |  | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -235,7 +239,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
|     auto scratch_slice = GetAllocationSlice(*custom_call, {1}); |     auto scratch_slice = GetAllocationSlice(*custom_call, {1}); | ||||||
| 
 | 
 | ||||||
|     AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>( |     AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>( | ||||||
|         Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices), |         context_->GetThunkInfo(custom_call), std::move(operand_slices), | ||||||
|         conv_result_slice, scratch_slice, tuple_result_slice)); |         conv_result_slice, scratch_slice, tuple_result_slice)); | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
| @ -269,22 +273,23 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
| 
 | 
 | ||||||
|     if (operand_buffer != a_buffer) { |     if (operand_buffer != a_buffer) { | ||||||
|       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( | ||||||
|  |           context_->GetThunkInfo(custom_call), | ||||||
|           /*source_address=*/operand_buffer, |           /*source_address=*/operand_buffer, | ||||||
|           /*destination_buffer=*/a_buffer, |           /*destination_buffer=*/a_buffer, | ||||||
|           /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call)); |           /*mem_size=*/ShapeUtil::ByteSizeOf(shape))); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     thunks.push_back(absl::make_unique<CholeskyThunk>( |     thunks.push_back(absl::make_unique<CholeskyThunk>( | ||||||
|         options, a_buffer, workspace_buffer, info_buffer, |         context_->GetThunkInfo(custom_call), options, a_buffer, | ||||||
|         custom_call->operand(0)->shape().element_type(), batch_size, n, |         workspace_buffer, info_buffer, | ||||||
|         custom_call)); |         custom_call->operand(0)->shape().element_type(), batch_size, n)); | ||||||
| 
 | 
 | ||||||
|     // Elide the sequential thunk if there's no copy.
 |     // Elide the sequential thunk if there's no copy.
 | ||||||
|     if (thunks.size() == 1) { |     if (thunks.size() == 1) { | ||||||
|       AddThunkToThunkSequence(std::move(thunks[0])); |       AddThunkToThunkSequence(std::move(thunks[0])); | ||||||
|     } else { |     } else { | ||||||
|       AddThunkToThunkSequence( |       AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( | ||||||
|           absl::make_unique<SequentialThunk>(std::move(thunks), custom_call)); |           context_->GetThunkInfo(custom_call), std::move(thunks))); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| @ -311,8 +316,9 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { | |||||||
|     ShapeTree<BufferAllocation::Slice> result_slices = |     ShapeTree<BufferAllocation::Slice> result_slices = | ||||||
|         get_slices_for_instr(custom_call); |         get_slices_for_instr(custom_call); | ||||||
|     AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>( |     AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>( | ||||||
|         call_target, std::move(operand_slices), std::move(result_slices), |         context_->GetThunkInfo(custom_call), call_target, | ||||||
|         Cast<HloCustomCallInstruction>(custom_call)->opaque(), custom_call)); |         std::move(operand_slices), std::move(result_slices), | ||||||
|  |         Cast<HloCustomCallInstruction>(custom_call)->opaque())); | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
| #endif | #endif | ||||||
| @ -347,9 +353,10 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) { | |||||||
|   auto destination_buffer = GetAllocationSlice(*hlo); |   auto destination_buffer = GetAllocationSlice(*hlo); | ||||||
|   if (operand_buffer != destination_buffer) { |   if (operand_buffer != destination_buffer) { | ||||||
|     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( |     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( | ||||||
|  |         context_->GetThunkInfo(hlo), | ||||||
|         /*source_address=*/operand_buffer, |         /*source_address=*/operand_buffer, | ||||||
|         /*destination_buffer=*/destination_buffer, |         /*destination_buffer=*/destination_buffer, | ||||||
|         /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo)); |         /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()))); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   thunks.push_back(BuildTriangularSolveThunk(hlo)); |   thunks.push_back(BuildTriangularSolveThunk(hlo)); | ||||||
| @ -358,8 +365,8 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) { | |||||||
|   if (thunks.size() == 1) { |   if (thunks.size() == 1) { | ||||||
|     AddThunkToThunkSequence(std::move(thunks[0])); |     AddThunkToThunkSequence(std::move(thunks[0])); | ||||||
|   } else { |   } else { | ||||||
|     AddThunkToThunkSequence( |     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>( | ||||||
|         absl::make_unique<SequentialThunk>(std::move(thunks), hlo)); |         context_->GetThunkInfo(hlo), std::move(thunks))); | ||||||
|   } |   } | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| @ -374,5 +381,12 @@ Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) { | |||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo( | ||||||
|  |     const HloInstruction* hlo) const { | ||||||
|  |   CHECK(hlo); | ||||||
|  |   Thunk::ThunkInfo info; | ||||||
|  |   info.hlo_instruction = hlo; | ||||||
|  |   return info; | ||||||
|  | } | ||||||
| }  // namespace gpu
 | }  // namespace gpu
 | ||||||
| }  // namespace xla
 | }  // namespace xla
 | ||||||
|  | |||||||
| @ -36,6 +36,7 @@ class ThunkEmitter { | |||||||
|         const HloInstruction& hlo, const ShapeIndex& index) const = 0; |         const HloInstruction& hlo, const ShapeIndex& index) const = 0; | ||||||
|     virtual int64 ByteSizeOf(const Shape& shape) const = 0; |     virtual int64 ByteSizeOf(const Shape& shape) const = 0; | ||||||
|     virtual absl::string_view platform_name() const = 0; |     virtual absl::string_view platform_name() const = 0; | ||||||
|  |     virtual Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const; | ||||||
| 
 | 
 | ||||||
|     virtual ~EmissionContext() = default; |     virtual ~EmissionContext() = default; | ||||||
|   }; |   }; | ||||||
|  | |||||||
| @ -32,12 +32,12 @@ namespace xla { | |||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| TriangularSolveThunk::TriangularSolveThunk( | TriangularSolveThunk::TriangularSolveThunk( | ||||||
|     const TriangularSolveOptions& options, |     ThunkInfo thunk_info, const TriangularSolveOptions& options, | ||||||
|     const BufferAllocation::Slice& a_buffer, |     const BufferAllocation::Slice& a_buffer, | ||||||
|     const BufferAllocation::Slice& b_buffer, PrimitiveType type, |     const BufferAllocation::Slice& b_buffer, PrimitiveType type, | ||||||
|     int64 batch_size, int64 m, int64 n, int64 a_batch_stride, |     int64 batch_size, int64 m, int64 n, int64 a_batch_stride, | ||||||
|     int64 b_batch_stride, const HloInstruction* hlo) |     int64 b_batch_stride) | ||||||
|     : Thunk(Kind::kTriangularSolve, hlo), |     : Thunk(Kind::kTriangularSolve, thunk_info), | ||||||
|       uplo_(options.lower() ? se::blas::UpperLower::kLower |       uplo_(options.lower() ? se::blas::UpperLower::kLower | ||||||
|                             : se::blas::UpperLower::kUpper), |                             : se::blas::UpperLower::kUpper), | ||||||
|       side_(options.left_side() ? se::blas::Side::kLeft |       side_(options.left_side() ? se::blas::Side::kLeft | ||||||
|  | |||||||
| @ -38,12 +38,12 @@ namespace gpu { | |||||||
| // Thread-compatible.
 | // Thread-compatible.
 | ||||||
| class TriangularSolveThunk : public Thunk { | class TriangularSolveThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   TriangularSolveThunk(const TriangularSolveOptions& options, |   TriangularSolveThunk(ThunkInfo thunk_info, | ||||||
|  |                        const TriangularSolveOptions& options, | ||||||
|                        const BufferAllocation::Slice& a_buffer, |                        const BufferAllocation::Slice& a_buffer, | ||||||
|                        const BufferAllocation::Slice& b_buffer, |                        const BufferAllocation::Slice& b_buffer, | ||||||
|                        PrimitiveType type, int64 batch_size, int64 m, int64 n, |                        PrimitiveType type, int64 batch_size, int64 m, int64 n, | ||||||
|                        int64 a_batch_stride, int64 b_batch_stride, |                        int64 a_batch_stride, int64 b_batch_stride); | ||||||
|                        const HloInstruction* hlo); |  | ||||||
| 
 | 
 | ||||||
|   TriangularSolveThunk(const TriangularSolveThunk&) = delete; |   TriangularSolveThunk(const TriangularSolveThunk&) = delete; | ||||||
|   TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; |   TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; | ||||||
|  | |||||||
| @ -34,7 +34,7 @@ Status TupleThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = |   auto op_profiler = | ||||||
|       params.profiler->MakeScopedInstructionProfiler(hlo_instruction()); |       params.profiler->MakeScopedInstructionProfiler(profile_index()); | ||||||
|   SafeH2DMemcpy(se::DeviceMemory<void*>( |   SafeH2DMemcpy(se::DeviceMemory<void*>( | ||||||
|                     buffer_allocations.GetDeviceAddress(dest_buffer_)), |                     buffer_allocations.GetDeviceAddress(dest_buffer_)), | ||||||
|                 std::move(tuple_data), n, &stream, |                 std::move(tuple_data), n, &stream, | ||||||
|  | |||||||
| @ -34,10 +34,10 @@ namespace gpu { | |||||||
| // issue (b/31336476).
 | // issue (b/31336476).
 | ||||||
| class TupleThunk : public Thunk { | class TupleThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   TupleThunk(absl::Span<const BufferAllocation::Slice> tuple_element_buffers, |   TupleThunk(ThunkInfo thunk_info, | ||||||
|              const BufferAllocation::Slice& dest_buffer, |              absl::Span<const BufferAllocation::Slice> tuple_element_buffers, | ||||||
|              const HloInstruction* hlo_instruction) |              const BufferAllocation::Slice& dest_buffer) | ||||||
|       : Thunk(Kind::kTuple, hlo_instruction), |       : Thunk(Kind::kTuple, thunk_info), | ||||||
|         tuple_element_buffers_(tuple_element_buffers.begin(), |         tuple_element_buffers_(tuple_element_buffers.begin(), | ||||||
|                                tuple_element_buffers.end()), |                                tuple_element_buffers.end()), | ||||||
|         dest_buffer_(dest_buffer) {} |         dest_buffer_(dest_buffer) {} | ||||||
|  | |||||||
| @ -24,20 +24,20 @@ namespace xla { | |||||||
| namespace gpu { | namespace gpu { | ||||||
| 
 | 
 | ||||||
| WhileThunk::WhileThunk( | WhileThunk::WhileThunk( | ||||||
|  |     ThunkInfo thunk_info, | ||||||
|     const BufferAllocation::Slice& condition_result_buffer_index, |     const BufferAllocation::Slice& condition_result_buffer_index, | ||||||
|     std::unique_ptr<ThunkSequence> condition_thunk_sequence, |     std::unique_ptr<ThunkSequence> condition_thunk_sequence, | ||||||
|     std::unique_ptr<ThunkSequence> body_thunk_sequence, |     std::unique_ptr<ThunkSequence> body_thunk_sequence) | ||||||
|     const HloInstruction* hlo) |     : Thunk(Kind::kWhile, thunk_info), | ||||||
|     : Thunk(Kind::kWhile, hlo), |  | ||||||
|       condition_result_buffer_index_(condition_result_buffer_index), |       condition_result_buffer_index_(condition_result_buffer_index), | ||||||
|       // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
 |       // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
 | ||||||
|       // and body_thunk_sequence_ constructors because these SequentialThunks
 |       // and body_thunk_sequence_ constructors because these SequentialThunks
 | ||||||
|       // are logically "part of" this WhileThunk, and shouldn't be profiled
 |       // are logically "part of" this WhileThunk, and shouldn't be profiled
 | ||||||
|       // separately from it.
 |       // separately from it.
 | ||||||
|       condition_thunk_sequence_(absl::make_unique<SequentialThunk>( |       condition_thunk_sequence_(absl::make_unique<SequentialThunk>( | ||||||
|           std::move(*condition_thunk_sequence), nullptr)), |           ThunkInfo(), std::move(*condition_thunk_sequence))), | ||||||
|       body_thunk_sequence_(absl::make_unique<SequentialThunk>( |       body_thunk_sequence_(absl::make_unique<SequentialThunk>( | ||||||
|           std::move(*body_thunk_sequence), nullptr)) {} |           ThunkInfo(), std::move(*body_thunk_sequence))) {} | ||||||
| 
 | 
 | ||||||
| void WhileThunk::ComputeAnnotations() { | void WhileThunk::ComputeAnnotations() { | ||||||
|   Thunk::ComputeAnnotations(); |   Thunk::ComputeAnnotations(); | ||||||
| @ -61,7 +61,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { | |||||||
|       params.buffer_allocations->GetDeviceAddress( |       params.buffer_allocations->GetDeviceAddress( | ||||||
|           condition_result_buffer_index_); |           condition_result_buffer_index_); | ||||||
| 
 | 
 | ||||||
|   auto op_profiler = profiler.MakeScopedInstructionProfiler(hlo_instruction()); |   auto op_profiler = profiler.MakeScopedInstructionProfiler(profile_index()); | ||||||
|   while (true) { |   while (true) { | ||||||
|     // Invoke thunk sequence for while 'condition' computation.
 |     // Invoke thunk sequence for while 'condition' computation.
 | ||||||
|     profiler.StartHloComputation(); |     profiler.StartHloComputation(); | ||||||
|  | |||||||
| @ -39,10 +39,10 @@ namespace gpu { | |||||||
| class WhileThunk : public Thunk { | class WhileThunk : public Thunk { | ||||||
|  public: |  public: | ||||||
|   // Constructs a WhileThunk to compute while instruction 'hlo'.
 |   // Constructs a WhileThunk to compute while instruction 'hlo'.
 | ||||||
|   WhileThunk(const BufferAllocation::Slice& condition_result_buffer_index, |   WhileThunk(ThunkInfo thunk_info, | ||||||
|  |              const BufferAllocation::Slice& condition_result_buffer_index, | ||||||
|              std::unique_ptr<ThunkSequence> condition_thunk_sequence, |              std::unique_ptr<ThunkSequence> condition_thunk_sequence, | ||||||
|              std::unique_ptr<ThunkSequence> body_thunk_sequence, |              std::unique_ptr<ThunkSequence> body_thunk_sequence); | ||||||
|              const HloInstruction* hlo); |  | ||||||
|   WhileThunk(const WhileThunk&) = delete; |   WhileThunk(const WhileThunk&) = delete; | ||||||
|   WhileThunk& operator=(const WhileThunk&) = delete; |   WhileThunk& operator=(const WhileThunk&) = delete; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -133,8 +133,12 @@ HloExecutionProfile::HloExecutionProfile( | |||||||
| 
 | 
 | ||||||
| void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, | void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, | ||||||
|                                            uint64 cycles_taken) { |                                            uint64 cycles_taken) { | ||||||
|   profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] = |   SetCyclesTakenBy(hlo_profile_index_map_.GetProfileIndexFor(*hlo), | ||||||
|       cycles_taken; |                    cycles_taken); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void HloExecutionProfile::SetCyclesTakenBy(size_t index, uint64 cycles_taken) { | ||||||
|  |   profile_counters_[index] = cycles_taken; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { | uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { | ||||||
|  | |||||||
| @ -114,6 +114,9 @@ class HloExecutionProfile { | |||||||
|   // Record how many cycles this HLO took to execute.
 |   // Record how many cycles this HLO took to execute.
 | ||||||
|   void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken); |   void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken); | ||||||
| 
 | 
 | ||||||
|  |   // Record how many cycles this HLO took to execute.
 | ||||||
|  |   void SetCyclesTakenBy(size_t index, uint64 cycles_taken); | ||||||
|  | 
 | ||||||
|   // Returns how many cycles this HLO took to execute.  Profiling information
 |   // Returns how many cycles this HLO took to execute.  Profiling information
 | ||||||
|   // may not be available for some instructions in which case zero is returned.
 |   // may not be available for some instructions in which case zero is returned.
 | ||||||
|   uint64 GetCyclesTakenBy(const HloInstruction& hlo) const; |   uint64 GetCyclesTakenBy(const HloInstruction& hlo) const; | ||||||
|  | |||||||
| @ -88,6 +88,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, | |||||||
|       const HloInstruction& hlo, const ShapeIndex& index) const override; |       const HloInstruction& hlo, const ShapeIndex& index) const override; | ||||||
|   int64 ByteSizeOf(const Shape& shape) const override; |   int64 ByteSizeOf(const Shape& shape) const override; | ||||||
|   absl::string_view platform_name() const override; |   absl::string_view platform_name() const override; | ||||||
|  | 
 | ||||||
|   mlir::Location getLocation(const HloInstruction* instr) const; |   mlir::Location getLocation(const HloInstruction* instr) const; | ||||||
| 
 | 
 | ||||||
|   xla::mlir_gpu::EmissionContext* emission_context_; |   xla::mlir_gpu::EmissionContext* emission_context_; | ||||||
|  | |||||||
| @ -436,8 +436,10 @@ StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk( | |||||||
|       kernel, operand_to_value_map, ordered_operands, assignment, buffers)); |       kernel, operand_to_value_map, ordered_operands, assignment, buffers)); | ||||||
| 
 | 
 | ||||||
|   // Finally, create the thunk and set the launch dimensions.
 |   // Finally, create the thunk and set the launch dimensions.
 | ||||||
|   auto thunk = absl::make_unique<gpu::KernelThunk>( |   gpu::Thunk::ThunkInfo info; | ||||||
|       buffers, kernel.getName().str(), instr); |   info.hlo_instruction = instr; | ||||||
|  |   auto thunk = absl::make_unique<gpu::KernelThunk>(info, buffers, | ||||||
|  |                                                    kernel.getName().str()); | ||||||
| 
 | 
 | ||||||
|   // Set launch bounds.
 |   // Set launch bounds.
 | ||||||
|   mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues(); |   mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues(); | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user