[XLA:GPU] Eliminate tuple population from batch norm thunks

- These tuples on the GPU side should not be directly used by anyone since XLA should
   folded the GetTupleElement into which these tuple feeds.

PiperOrigin-RevId: 346672532
Change-Id: Ia5980e14ddd157d84fe60cb40dd4ebc2f5a77c9e
This commit is contained in:
Rahul Joshi 2020-12-09 17:19:54 -08:00 committed by TensorFlower Gardener
parent 4d1c107bef
commit 5a08d776c6
3 changed files with 16 additions and 48 deletions

View File

@ -76,8 +76,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple)
const BufferAllocation::Slice& output_inv_stddev)
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
config_(std::move(config)),
operand_(operand),
@ -85,8 +84,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
offset_(offset),
output_data_(output_data),
output_mean_(output_mean),
output_inv_stddev_(output_inv_stddev),
output_tuple_(output_tuple) {}
output_inv_stddev_(output_inv_stddev) {}
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
const ExecuteParams& params) {
@ -110,16 +108,6 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
&stream));
// Write the output tuple.
const int kNumOutputs = 3;
auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
ptrs[0] = output_data.opaque();
ptrs[1] = output_mean.opaque();
ptrs[2] = output_inv_stddev.opaque();
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(output_tuple_));
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
params.deferred_host_callbacks);
if (!stream.ok()) {
return InternalError("BatchNormalizationTraining call failed.");
}
@ -134,8 +122,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
const BufferAllocation::Slice& grad_output,
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple)
const BufferAllocation::Slice& output_grad_offset)
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
config_(std::move(config)),
operand_(operand),
@ -145,8 +132,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
grad_output_(grad_output),
output_grad_data_(output_grad_data),
output_grad_scale_(output_grad_scale),
output_grad_offset_(output_grad_offset),
output_tuple_(output_tuple) {}
output_grad_offset_(output_grad_offset) {}
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
const ExecuteParams& params) {
@ -172,17 +158,6 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
stream));
// Write the output tuple.
const int kNumOutputs = 3;
auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
ptrs[0] = output_grad_data.opaque();
ptrs[1] = output_grad_scale.opaque();
ptrs[2] = output_grad_offset.opaque();
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(output_tuple_));
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream,
params.deferred_host_callbacks);
if (!stream->ok()) {
return InternalError("BatchNormalizationBackward call failed.");
}

View File

@ -82,8 +82,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
const BufferAllocation::Slice& offset,
const BufferAllocation::Slice& output_data,
const BufferAllocation::Slice& output_mean,
const BufferAllocation::Slice& output_inv_stddev,
const BufferAllocation::Slice& output_tuple);
const BufferAllocation::Slice& output_inv_stddev);
CudnnBatchNormForwardTrainingThunk(
const CudnnBatchNormForwardTrainingThunk&) = delete;
@ -100,22 +99,19 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
BufferAllocation::Slice output_data_;
BufferAllocation::Slice output_mean_;
BufferAllocation::Slice output_inv_stddev_;
BufferAllocation::Slice output_tuple_;
};
class CudnnBatchNormBackwardThunk : public Thunk {
public:
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
CudnnBatchNormConfig&& config,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale,
const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
const BufferAllocation::Slice& grad_output,
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset,
const BufferAllocation::Slice& output_tuple);
CudnnBatchNormBackwardThunk(
ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
const BufferAllocation::Slice& operand,
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
const BufferAllocation::Slice& inv_stddev,
const BufferAllocation::Slice& grad_output,
const BufferAllocation::Slice& output_grad_data,
const BufferAllocation::Slice& output_grad_scale,
const BufferAllocation::Slice& output_grad_offset);
CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =
@ -133,7 +129,6 @@ class CudnnBatchNormBackwardThunk : public Thunk {
BufferAllocation::Slice output_grad_data_;
BufferAllocation::Slice output_grad_scale_;
BufferAllocation::Slice output_grad_offset_;
BufferAllocation::Slice output_tuple_;
};
} // namespace gpu

View File

@ -258,8 +258,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
/*output_data=*/output_data,
/*output_mean=*/output_mean,
/*output_inv_stddev=*/output_inv_stddev,
/*output_tuple=*/GetAllocationSlice(*custom_call)));
/*output_inv_stddev=*/output_inv_stddev));
return Status::OK();
}
@ -295,8 +294,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
/*output_grad_data=*/output_grad_data,
/*output_grad_scale=*/output_grad_scale,
/*output_grad_offset=*/output_grad_offset,
/*output_tuple=*/GetAllocationSlice(*custom_call)));
/*output_grad_offset=*/output_grad_offset));
return Status::OK();
}