From 5a08d776c6292beef5a4e93c3d908c8fe4a91a66 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 9 Dec 2020 17:19:54 -0800 Subject: [PATCH] [XLA:GPU] Eliminate tuple population from batch norm thunks - These tuples on the GPU side should not be directly used by anyone since XLA should have folded the GetTupleElement into which these tuples feed. PiperOrigin-RevId: 346672532 Change-Id: Ia5980e14ddd157d84fe60cb40dd4ebc2f5a77c9e --- .../xla/service/gpu/cudnn_batchnorm_thunk.cc | 33 +++---------------- .../xla/service/gpu/cudnn_batchnorm_thunk.h | 25 ++++++-------- .../compiler/xla/service/gpu/thunk_emitter.cc | 6 ++-- 3 files changed, 16 insertions(+), 48 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index dae490e0d18..8d70bb2f424 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -76,8 +76,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, - const BufferAllocation::Slice& output_inv_stddev, - const BufferAllocation::Slice& output_tuple) + const BufferAllocation::Slice& output_inv_stddev) : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info), config_(std::move(config)), operand_(operand), @@ -85,8 +84,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( offset_(offset), output_data_(output_data), output_mean_(output_mean), - output_inv_stddev_(output_inv_stddev), - output_tuple_(output_tuple) {} + output_inv_stddev_(output_inv_stddev) {} Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -110,16 +108,6 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( 
se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), &stream)); - // Write the output tuple. - const int kNumOutputs = 3; - auto ptrs = absl::make_unique(kNumOutputs); - ptrs[0] = output_data.opaque(); - ptrs[1] = output_mean.opaque(); - ptrs[2] = output_inv_stddev.opaque(); - se::DeviceMemory tuple_addr( - buffer_allocations.GetDeviceAddress(output_tuple_)); - SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream, - params.deferred_host_callbacks); if (!stream.ok()) { return InternalError("BatchNormalizationTraining call failed."); } @@ -134,8 +122,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( const BufferAllocation::Slice& grad_output, const BufferAllocation::Slice& output_grad_data, const BufferAllocation::Slice& output_grad_scale, - const BufferAllocation::Slice& output_grad_offset, - const BufferAllocation::Slice& output_tuple) + const BufferAllocation::Slice& output_grad_offset) : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info), config_(std::move(config)), operand_(operand), @@ -145,8 +132,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( grad_output_(grad_output), output_grad_data_(output_grad_data), output_grad_scale_(output_grad_scale), - output_grad_offset_(output_grad_offset), - output_tuple_(output_tuple) {} + output_grad_offset_(output_grad_offset) {} Status CudnnBatchNormBackwardThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -172,17 +158,6 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( se::DeviceMemory(buffer_allocations.GetDeviceAddress(inv_stddev_)), stream)); - // Write the output tuple. 
- const int kNumOutputs = 3; - auto ptrs = absl::make_unique(kNumOutputs); - ptrs[0] = output_grad_data.opaque(); - ptrs[1] = output_grad_scale.opaque(); - ptrs[2] = output_grad_offset.opaque(); - se::DeviceMemory tuple_addr( - buffer_allocations.GetDeviceAddress(output_tuple_)); - SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream, - params.deferred_host_callbacks); - if (!stream->ok()) { return InternalError("BatchNormalizationBackward call failed."); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index d45e284ea2c..48c46a6bc08 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -82,8 +82,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { const BufferAllocation::Slice& offset, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, - const BufferAllocation::Slice& output_inv_stddev, - const BufferAllocation::Slice& output_tuple); + const BufferAllocation::Slice& output_inv_stddev); CudnnBatchNormForwardTrainingThunk( const CudnnBatchNormForwardTrainingThunk&) = delete; @@ -100,22 +99,19 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { BufferAllocation::Slice output_data_; BufferAllocation::Slice output_mean_; BufferAllocation::Slice output_inv_stddev_; - BufferAllocation::Slice output_tuple_; }; class CudnnBatchNormBackwardThunk : public Thunk { public: - CudnnBatchNormBackwardThunk(ThunkInfo thunk_info, - CudnnBatchNormConfig&& config, - const BufferAllocation::Slice& operand, - const BufferAllocation::Slice& scale, - const BufferAllocation::Slice& mean, - const BufferAllocation::Slice& inv_stddev, - const BufferAllocation::Slice& grad_output, - const BufferAllocation::Slice& output_grad_data, - const BufferAllocation::Slice& output_grad_scale, - const BufferAllocation::Slice& output_grad_offset, - const 
BufferAllocation::Slice& output_tuple); + CudnnBatchNormBackwardThunk( + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& inv_stddev, + const BufferAllocation::Slice& grad_output, + const BufferAllocation::Slice& output_grad_data, + const BufferAllocation::Slice& output_grad_scale, + const BufferAllocation::Slice& output_grad_offset); CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete; CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) = @@ -133,7 +129,6 @@ class CudnnBatchNormBackwardThunk : public Thunk { BufferAllocation::Slice output_grad_data_; BufferAllocation::Slice output_grad_scale_; BufferAllocation::Slice output_grad_offset_; - BufferAllocation::Slice output_tuple_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc index d401f1d894b..058aad76777 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -258,8 +258,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { /*offset=*/GetAllocationSlice(*custom_call->operand(2)), /*output_data=*/output_data, /*output_mean=*/output_mean, - /*output_inv_stddev=*/output_inv_stddev, - /*output_tuple=*/GetAllocationSlice(*custom_call))); + /*output_inv_stddev=*/output_inv_stddev)); return Status::OK(); } @@ -295,8 +294,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), /*output_grad_data=*/output_grad_data, /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call))); + /*output_grad_offset=*/output_grad_offset)); return Status::OK(); }