[XLA:GPU] Eliminate tuple population from batch norm thunks
- These tuples on the GPU side should not be directly used by anyone since XLA should folded the GetTupleElement into which these tuple feeds. PiperOrigin-RevId: 346672532 Change-Id: Ia5980e14ddd157d84fe60cb40dd4ebc2f5a77c9e
This commit is contained in:
parent
4d1c107bef
commit
5a08d776c6
@ -76,8 +76,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
|
||||
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
|
||||
const BufferAllocation::Slice& output_data,
|
||||
const BufferAllocation::Slice& output_mean,
|
||||
const BufferAllocation::Slice& output_inv_stddev,
|
||||
const BufferAllocation::Slice& output_tuple)
|
||||
const BufferAllocation::Slice& output_inv_stddev)
|
||||
: Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
|
||||
config_(std::move(config)),
|
||||
operand_(operand),
|
||||
@ -85,8 +84,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
|
||||
offset_(offset),
|
||||
output_data_(output_data),
|
||||
output_mean_(output_mean),
|
||||
output_inv_stddev_(output_inv_stddev),
|
||||
output_tuple_(output_tuple) {}
|
||||
output_inv_stddev_(output_inv_stddev) {}
|
||||
|
||||
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
|
||||
const ExecuteParams& params) {
|
||||
@ -110,16 +108,6 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
|
||||
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
|
||||
&stream));
|
||||
|
||||
// Write the output tuple.
|
||||
const int kNumOutputs = 3;
|
||||
auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
|
||||
ptrs[0] = output_data.opaque();
|
||||
ptrs[1] = output_mean.opaque();
|
||||
ptrs[2] = output_inv_stddev.opaque();
|
||||
se::DeviceMemory<void*> tuple_addr(
|
||||
buffer_allocations.GetDeviceAddress(output_tuple_));
|
||||
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
|
||||
params.deferred_host_callbacks);
|
||||
if (!stream.ok()) {
|
||||
return InternalError("BatchNormalizationTraining call failed.");
|
||||
}
|
||||
@ -134,8 +122,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
|
||||
const BufferAllocation::Slice& grad_output,
|
||||
const BufferAllocation::Slice& output_grad_data,
|
||||
const BufferAllocation::Slice& output_grad_scale,
|
||||
const BufferAllocation::Slice& output_grad_offset,
|
||||
const BufferAllocation::Slice& output_tuple)
|
||||
const BufferAllocation::Slice& output_grad_offset)
|
||||
: Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
|
||||
config_(std::move(config)),
|
||||
operand_(operand),
|
||||
@ -145,8 +132,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
|
||||
grad_output_(grad_output),
|
||||
output_grad_data_(output_grad_data),
|
||||
output_grad_scale_(output_grad_scale),
|
||||
output_grad_offset_(output_grad_offset),
|
||||
output_tuple_(output_tuple) {}
|
||||
output_grad_offset_(output_grad_offset) {}
|
||||
|
||||
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
|
||||
const ExecuteParams& params) {
|
||||
@ -172,17 +158,6 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
|
||||
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
|
||||
stream));
|
||||
|
||||
// Write the output tuple.
|
||||
const int kNumOutputs = 3;
|
||||
auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
|
||||
ptrs[0] = output_grad_data.opaque();
|
||||
ptrs[1] = output_grad_scale.opaque();
|
||||
ptrs[2] = output_grad_offset.opaque();
|
||||
se::DeviceMemory<void*> tuple_addr(
|
||||
buffer_allocations.GetDeviceAddress(output_tuple_));
|
||||
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream,
|
||||
params.deferred_host_callbacks);
|
||||
|
||||
if (!stream->ok()) {
|
||||
return InternalError("BatchNormalizationBackward call failed.");
|
||||
}
|
||||
|
||||
@ -82,8 +82,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
|
||||
const BufferAllocation::Slice& offset,
|
||||
const BufferAllocation::Slice& output_data,
|
||||
const BufferAllocation::Slice& output_mean,
|
||||
const BufferAllocation::Slice& output_inv_stddev,
|
||||
const BufferAllocation::Slice& output_tuple);
|
||||
const BufferAllocation::Slice& output_inv_stddev);
|
||||
|
||||
CudnnBatchNormForwardTrainingThunk(
|
||||
const CudnnBatchNormForwardTrainingThunk&) = delete;
|
||||
@ -100,22 +99,19 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
|
||||
BufferAllocation::Slice output_data_;
|
||||
BufferAllocation::Slice output_mean_;
|
||||
BufferAllocation::Slice output_inv_stddev_;
|
||||
BufferAllocation::Slice output_tuple_;
|
||||
};
|
||||
|
||||
class CudnnBatchNormBackwardThunk : public Thunk {
|
||||
public:
|
||||
CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
|
||||
CudnnBatchNormConfig&& config,
|
||||
const BufferAllocation::Slice& operand,
|
||||
const BufferAllocation::Slice& scale,
|
||||
const BufferAllocation::Slice& mean,
|
||||
const BufferAllocation::Slice& inv_stddev,
|
||||
const BufferAllocation::Slice& grad_output,
|
||||
const BufferAllocation::Slice& output_grad_data,
|
||||
const BufferAllocation::Slice& output_grad_scale,
|
||||
const BufferAllocation::Slice& output_grad_offset,
|
||||
const BufferAllocation::Slice& output_tuple);
|
||||
CudnnBatchNormBackwardThunk(
|
||||
ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
|
||||
const BufferAllocation::Slice& operand,
|
||||
const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
|
||||
const BufferAllocation::Slice& inv_stddev,
|
||||
const BufferAllocation::Slice& grad_output,
|
||||
const BufferAllocation::Slice& output_grad_data,
|
||||
const BufferAllocation::Slice& output_grad_scale,
|
||||
const BufferAllocation::Slice& output_grad_offset);
|
||||
|
||||
CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
|
||||
CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =
|
||||
@ -133,7 +129,6 @@ class CudnnBatchNormBackwardThunk : public Thunk {
|
||||
BufferAllocation::Slice output_grad_data_;
|
||||
BufferAllocation::Slice output_grad_scale_;
|
||||
BufferAllocation::Slice output_grad_offset_;
|
||||
BufferAllocation::Slice output_tuple_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
||||
@ -258,8 +258,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
|
||||
/*output_data=*/output_data,
|
||||
/*output_mean=*/output_mean,
|
||||
/*output_inv_stddev=*/output_inv_stddev,
|
||||
/*output_tuple=*/GetAllocationSlice(*custom_call)));
|
||||
/*output_inv_stddev=*/output_inv_stddev));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -295,8 +294,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
/*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
|
||||
/*output_grad_data=*/output_grad_data,
|
||||
/*output_grad_scale=*/output_grad_scale,
|
||||
/*output_grad_offset=*/output_grad_offset,
|
||||
/*output_tuple=*/GetAllocationSlice(*custom_call)));
|
||||
/*output_grad_offset=*/output_grad_offset));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user