[XLA:GPU] Eliminate tuple population from batch norm thunks

- On the GPU side, these tuples should not be read directly by anyone, since XLA should
  already have folded away the GetTupleElement instructions into which these tuples feed.
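
For context: XLA:GPU materializes a tuple-shaped result as a small device-side table of pointers, one per element, while each element lives in its own buffer slice. The toy C++ program below (illustrative names only, no XLA or CUDA dependencies) sketches why the table is dead weight once GetTupleElement is folded: consumers reach the element buffer directly and never go through the table.

    #include <array>
    #include <cassert>

    // Stand-ins for the three element buffers, each with its own slice.
    static float output_data[16], output_mean[4], output_inv_stddev[4];

    int main() {
      // "Tuple population": build a table of pointers to the elements.
      // This is what the deleted thunk code copied to device memory.
      std::array<void*, 3> tuple = {output_data, output_mean,
                                    output_inv_stddev};

      // A consumer could dereference the table...
      float* mean_via_tuple = static_cast<float*>(tuple[1]);

      // ...but a folded GetTupleElement is assigned the element's own
      // slice, so the table itself is never read:
      float* mean_direct = output_mean;

      assert(mean_via_tuple == mean_direct);  // same buffer either way
      return 0;
    }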

PiperOrigin-RevId: 346672532
Change-Id: Ia5980e14ddd157d84fe60cb40dd4ebc2f5a77c9e
Author: Rahul Joshi, 2020-12-09 17:19:54 -08:00 (committed by TensorFlower Gardener)
Commit: 5a08d776c6 (parent: 4d1c107bef)
3 changed files with 16 additions and 48 deletions


@@ -76,8 +76,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
     const BufferAllocation::Slice& output_data,
     const BufferAllocation::Slice& output_mean,
-    const BufferAllocation::Slice& output_inv_stddev,
-    const BufferAllocation::Slice& output_tuple)
+    const BufferAllocation::Slice& output_inv_stddev)
     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
       config_(std::move(config)),
       operand_(operand),
@@ -85,8 +84,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
       offset_(offset),
       output_data_(output_data),
       output_mean_(output_mean),
-      output_inv_stddev_(output_inv_stddev),
-      output_tuple_(output_tuple) {}
+      output_inv_stddev_(output_inv_stddev) {}
 
 Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -110,16 +108,6 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
       &stream));
 
-  // Write the output tuple.
-  const int kNumOutputs = 3;
-  auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
-  ptrs[0] = output_data.opaque();
-  ptrs[1] = output_mean.opaque();
-  ptrs[2] = output_inv_stddev.opaque();
-  se::DeviceMemory<void*> tuple_addr(
-      buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
-                params.deferred_host_callbacks);
   if (!stream.ok()) {
     return InternalError("BatchNormalizationTraining call failed.");
   }
@@ -134,8 +122,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
     const BufferAllocation::Slice& grad_output,
     const BufferAllocation::Slice& output_grad_data,
     const BufferAllocation::Slice& output_grad_scale,
-    const BufferAllocation::Slice& output_grad_offset,
-    const BufferAllocation::Slice& output_tuple)
+    const BufferAllocation::Slice& output_grad_offset)
     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
       config_(std::move(config)),
       operand_(operand),
@@ -145,8 +132,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
       grad_output_(grad_output),
       output_grad_data_(output_grad_data),
       output_grad_scale_(output_grad_scale),
-      output_grad_offset_(output_grad_offset),
-      output_tuple_(output_tuple) {}
+      output_grad_offset_(output_grad_offset) {}
 
 Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -172,17 +158,6 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
       stream));
 
-  // Write the output tuple.
-  const int kNumOutputs = 3;
-  auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
-  ptrs[0] = output_grad_data.opaque();
-  ptrs[1] = output_grad_scale.opaque();
-  ptrs[2] = output_grad_offset.opaque();
-  se::DeviceMemory<void*> tuple_addr(
-      buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream,
-                params.deferred_host_callbacks);
-
   if (!stream->ok()) {
     return InternalError("BatchNormalizationBackward call failed.");
   }
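
Both deleted blocks used SafeH2DMemcpy to publish the pointer table: enqueue an asynchronous host-to-device copy, then (judging by the `std::move(ptrs)` and `params.deferred_host_callbacks` arguments) keep the host-side array alive by parking its ownership in a deferred callback. A minimal sketch of that pattern, with simplified stand-in types rather than the real StreamExecutor API:

    #include <cstddef>
    #include <functional>
    #include <memory>
    #include <utility>
    #include <vector>

    // Simplified stand-in for se::Stream (assumption, not the real API).
    struct Stream {
      void ThenMemcpyH2D(void* dst, const void* src, size_t size) {
        // Would enqueue an async host-to-device copy on this stream.
        (void)dst; (void)src; (void)size;
      }
    };

    // Enqueue the copy, then hand ownership of the host buffer to a
    // deferred callback so the source stays valid until the copy is done.
    void SafeH2DMemcpySketch(void* dst, std::unique_ptr<void*[]> src,
                             int count, Stream* stream,
                             std::vector<std::function<void()>>* deferred) {
      stream->ThenMemcpyH2D(dst, src.get(), count * sizeof(void*));
      deferred->push_back(
          [buf = std::shared_ptr<void*[]>(std::move(src))] {
            // No-op: the capture alone keeps `buf` alive until the
            // callbacks run, after the stream has consumed the source.
            (void)buf;
          });
    }

With the tuples no longer populated, both the copy and the deferred callback disappear from the thunk's execution path.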


@@ -82,8 +82,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
       const BufferAllocation::Slice& offset,
       const BufferAllocation::Slice& output_data,
       const BufferAllocation::Slice& output_mean,
-      const BufferAllocation::Slice& output_inv_stddev,
-      const BufferAllocation::Slice& output_tuple);
+      const BufferAllocation::Slice& output_inv_stddev);
 
   CudnnBatchNormForwardTrainingThunk(
       const CudnnBatchNormForwardTrainingThunk&) = delete;
@@ -100,22 +99,19 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
   BufferAllocation::Slice output_data_;
   BufferAllocation::Slice output_mean_;
   BufferAllocation::Slice output_inv_stddev_;
-  BufferAllocation::Slice output_tuple_;
 };
 
 class CudnnBatchNormBackwardThunk : public Thunk {
  public:
-  CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
-                              CudnnBatchNormConfig&& config,
-                              const BufferAllocation::Slice& operand,
-                              const BufferAllocation::Slice& scale,
-                              const BufferAllocation::Slice& mean,
-                              const BufferAllocation::Slice& inv_stddev,
-                              const BufferAllocation::Slice& grad_output,
-                              const BufferAllocation::Slice& output_grad_data,
-                              const BufferAllocation::Slice& output_grad_scale,
-                              const BufferAllocation::Slice& output_grad_offset,
-                              const BufferAllocation::Slice& output_tuple);
+  CudnnBatchNormBackwardThunk(
+      ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
+      const BufferAllocation::Slice& operand,
+      const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
+      const BufferAllocation::Slice& inv_stddev,
+      const BufferAllocation::Slice& grad_output,
+      const BufferAllocation::Slice& output_grad_data,
+      const BufferAllocation::Slice& output_grad_scale,
+      const BufferAllocation::Slice& output_grad_offset);
 
   CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
   CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =
@@ -133,7 +129,6 @@ class CudnnBatchNormBackwardThunk : public Thunk {
   BufferAllocation::Slice output_grad_data_;
   BufferAllocation::Slice output_grad_scale_;
   BufferAllocation::Slice output_grad_offset_;
-  BufferAllocation::Slice output_tuple_;
 };
 
 }  // namespace gpu


@@ -258,8 +258,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
         /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
         /*output_data=*/output_data,
         /*output_mean=*/output_mean,
-        /*output_inv_stddev=*/output_inv_stddev,
-        /*output_tuple=*/GetAllocationSlice(*custom_call)));
+        /*output_inv_stddev=*/output_inv_stddev));
 
     return Status::OK();
   }
@@ -295,8 +294,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
         /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
         /*output_grad_data=*/output_grad_data,
         /*output_grad_scale=*/output_grad_scale,
-        /*output_grad_offset=*/output_grad_offset,
-        /*output_tuple=*/GetAllocationSlice(*custom_call)));
+        /*output_grad_offset=*/output_grad_offset));
 
     return Status::OK();
   }
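
The emitter-side change falls out of the same observation: buffer assignment hands out one slice per (instruction, shape-index) pair, and nothing asks for the slice of the tuple-shaped result itself anymore. A toy model of that mapping, with illustrative names rather than the real BufferAssignment API:

    #include <cassert>
    #include <map>
    #include <string>
    #include <utility>

    int main() {
      // Toy model: (instruction name, tuple index) -> slice id.
      std::map<std::pair<std::string, int>, int> slice_of;

      // The emitter registers only the element buffers of the custom
      // call; no entry is created for the tuple-shaped result itself.
      slice_of[{"batchnorm_training", 0}] = 0;  // output_data
      slice_of[{"batchnorm_training", 1}] = 1;  // output_mean
      slice_of[{"batchnorm_training", 2}] = 2;  // output_inv_stddev

      // A folded GetTupleElement(result, 1) resolves directly to the
      // element's slice; the tuple's pointer table is never consulted.
      assert(slice_of.at({"batchnorm_training", 1}) == 1);
      return 0;
    }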