diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index dae490e0d18..8d70bb2f424 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -76,8 +76,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, - const BufferAllocation::Slice& output_inv_stddev, - const BufferAllocation::Slice& output_tuple) + const BufferAllocation::Slice& output_inv_stddev) : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info), config_(std::move(config)), operand_(operand), @@ -85,8 +84,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( offset_(offset), output_data_(output_data), output_mean_(output_mean), - output_inv_stddev_(output_inv_stddev), - output_tuple_(output_tuple) {} + output_inv_stddev_(output_inv_stddev) {} Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -110,16 +108,6 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), &stream)); - // Write the output tuple. - const int kNumOutputs = 3; - auto ptrs = absl::make_unique(kNumOutputs); - ptrs[0] = output_data.opaque(); - ptrs[1] = output_mean.opaque(); - ptrs[2] = output_inv_stddev.opaque(); - se::DeviceMemory tuple_addr( - buffer_allocations.GetDeviceAddress(output_tuple_)); - SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream, - params.deferred_host_callbacks); if (!stream.ok()) { return InternalError("BatchNormalizationTraining call failed."); } @@ -134,8 +122,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( const BufferAllocation::Slice& grad_output, const BufferAllocation::Slice& output_grad_data, const BufferAllocation::Slice& output_grad_scale, - const BufferAllocation::Slice& output_grad_offset, - const BufferAllocation::Slice& output_tuple) + const BufferAllocation::Slice& output_grad_offset) : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info), config_(std::move(config)), operand_(operand), @@ -145,8 +132,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( grad_output_(grad_output), output_grad_data_(output_grad_data), output_grad_scale_(output_grad_scale), - output_grad_offset_(output_grad_offset), - output_tuple_(output_tuple) {} + output_grad_offset_(output_grad_offset) {} Status CudnnBatchNormBackwardThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -172,17 +158,6 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( se::DeviceMemory(buffer_allocations.GetDeviceAddress(inv_stddev_)), stream)); - // Write the output tuple. - const int kNumOutputs = 3; - auto ptrs = absl::make_unique(kNumOutputs); - ptrs[0] = output_grad_data.opaque(); - ptrs[1] = output_grad_scale.opaque(); - ptrs[2] = output_grad_offset.opaque(); - se::DeviceMemory tuple_addr( - buffer_allocations.GetDeviceAddress(output_tuple_)); - SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream, - params.deferred_host_callbacks); - if (!stream->ok()) { return InternalError("BatchNormalizationBackward call failed."); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index d45e284ea2c..48c46a6bc08 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -82,8 +82,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { const BufferAllocation::Slice& offset, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, - const BufferAllocation::Slice& output_inv_stddev, - const BufferAllocation::Slice& output_tuple); + const BufferAllocation::Slice& output_inv_stddev); CudnnBatchNormForwardTrainingThunk( const CudnnBatchNormForwardTrainingThunk&) = delete; @@ -100,22 +99,19 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { BufferAllocation::Slice output_data_; BufferAllocation::Slice output_mean_; BufferAllocation::Slice output_inv_stddev_; - BufferAllocation::Slice output_tuple_; }; class CudnnBatchNormBackwardThunk : public Thunk { public: - CudnnBatchNormBackwardThunk(ThunkInfo thunk_info, - CudnnBatchNormConfig&& config, - const BufferAllocation::Slice& operand, - const BufferAllocation::Slice& scale, - const BufferAllocation::Slice& mean, - const BufferAllocation::Slice& inv_stddev, - const BufferAllocation::Slice& grad_output, - const BufferAllocation::Slice& output_grad_data, - const BufferAllocation::Slice& output_grad_scale, - const BufferAllocation::Slice& output_grad_offset, - const BufferAllocation::Slice& output_tuple); + CudnnBatchNormBackwardThunk( + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& inv_stddev, + const BufferAllocation::Slice& grad_output, + const BufferAllocation::Slice& output_grad_data, + const BufferAllocation::Slice& output_grad_scale, + const BufferAllocation::Slice& output_grad_offset); CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete; CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) = @@ -133,7 +129,6 @@ class CudnnBatchNormBackwardThunk : public Thunk { BufferAllocation::Slice output_grad_data_; BufferAllocation::Slice output_grad_scale_; BufferAllocation::Slice output_grad_offset_; - BufferAllocation::Slice output_tuple_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc index d401f1d894b..058aad76777 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -258,8 +258,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { /*offset=*/GetAllocationSlice(*custom_call->operand(2)), /*output_data=*/output_data, /*output_mean=*/output_mean, - /*output_inv_stddev=*/output_inv_stddev, - /*output_tuple=*/GetAllocationSlice(*custom_call))); + /*output_inv_stddev=*/output_inv_stddev)); return Status::OK(); } @@ -295,8 +294,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), /*output_grad_data=*/output_grad_data, /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call))); + /*output_grad_offset=*/output_grad_offset)); return Status::OK(); }