[XLA:GPU] Eliminate tuple population from batch norm thunks
Nothing on the GPU side should read these tuples directly, since XLA should already have folded away the GetTupleElement instructions that these tuples feed.

PiperOrigin-RevId: 346672532
Change-Id: Ia5980e14ddd157d84fe60cb40dd4ebc2f5a77c9e
parent 4d1c107bef
commit 5a08d776c6
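For context, both ExecuteOnStream overrides used to follow the cuDNN launch by materializing the result tuple's index buffer: they copied a host array of the three output pointers into the device buffer backing the tuple. Condensed from the removed hunks below (forward-training variant; the backward variant is identical modulo names):

  // Pattern deleted from both thunks: build a host-side array of the
  // three output pointers and enqueue an H2D copy into the buffer that
  // backs the tuple result.
  const int kNumOutputs = 3;
  auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
  ptrs[0] = output_data.opaque();        // tuple element 0
  ptrs[1] = output_mean.opaque();        // tuple element 1
  ptrs[2] = output_inv_stddev.opaque();  // tuple element 2
  se::DeviceMemory<void*> tuple_addr(
      buffer_allocations.GetDeviceAddress(output_tuple_));
  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
                params.deferred_host_callbacks);

Since no GPU code ever dereferences tuple_addr, the copy, the output_tuple_ slice, and the output_tuple constructor parameter can all be dropped.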
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -76,8 +76,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
     const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset,
     const BufferAllocation::Slice& output_data,
     const BufferAllocation::Slice& output_mean,
-    const BufferAllocation::Slice& output_inv_stddev,
-    const BufferAllocation::Slice& output_tuple)
+    const BufferAllocation::Slice& output_inv_stddev)
     : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info),
       config_(std::move(config)),
       operand_(operand),
@@ -85,8 +84,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
       offset_(offset),
       output_data_(output_data),
       output_mean_(output_mean),
-      output_inv_stddev_(output_inv_stddev),
-      output_tuple_(output_tuple) {}
+      output_inv_stddev_(output_inv_stddev) {}
 
 Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -110,16 +108,6 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(offset_)),
       &stream));
 
-  // Write the output tuple.
-  const int kNumOutputs = 3;
-  auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
-  ptrs[0] = output_data.opaque();
-  ptrs[1] = output_mean.opaque();
-  ptrs[2] = output_inv_stddev.opaque();
-  se::DeviceMemory<void*> tuple_addr(
-      buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
-                params.deferred_host_callbacks);
   if (!stream.ok()) {
     return InternalError("BatchNormalizationTraining call failed.");
   }
@@ -134,8 +122,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
     const BufferAllocation::Slice& grad_output,
     const BufferAllocation::Slice& output_grad_data,
     const BufferAllocation::Slice& output_grad_scale,
-    const BufferAllocation::Slice& output_grad_offset,
-    const BufferAllocation::Slice& output_tuple)
+    const BufferAllocation::Slice& output_grad_offset)
     : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info),
       config_(std::move(config)),
       operand_(operand),
@@ -145,8 +132,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
       grad_output_(grad_output),
       output_grad_data_(output_grad_data),
       output_grad_scale_(output_grad_scale),
-      output_grad_offset_(output_grad_offset),
-      output_tuple_(output_tuple) {}
+      output_grad_offset_(output_grad_offset) {}
 
 Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
     const ExecuteParams& params) {
@@ -172,17 +158,6 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
       se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(inv_stddev_)),
       stream));
 
-  // Write the output tuple.
-  const int kNumOutputs = 3;
-  auto ptrs = absl::make_unique<void*[]>(kNumOutputs);
-  ptrs[0] = output_grad_data.opaque();
-  ptrs[1] = output_grad_scale.opaque();
-  ptrs[2] = output_grad_offset.opaque();
-  se::DeviceMemory<void*> tuple_addr(
-      buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream,
-                params.deferred_host_callbacks);
-
   if (!stream->ok()) {
     return InternalError("BatchNormalizationBackward call failed.");
   }
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
@@ -82,8 +82,7 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
       const BufferAllocation::Slice& offset,
       const BufferAllocation::Slice& output_data,
       const BufferAllocation::Slice& output_mean,
-      const BufferAllocation::Slice& output_inv_stddev,
-      const BufferAllocation::Slice& output_tuple);
+      const BufferAllocation::Slice& output_inv_stddev);
 
   CudnnBatchNormForwardTrainingThunk(
       const CudnnBatchNormForwardTrainingThunk&) = delete;
@@ -100,22 +99,19 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
   BufferAllocation::Slice output_data_;
   BufferAllocation::Slice output_mean_;
   BufferAllocation::Slice output_inv_stddev_;
-  BufferAllocation::Slice output_tuple_;
 };
 
 class CudnnBatchNormBackwardThunk : public Thunk {
  public:
-  CudnnBatchNormBackwardThunk(ThunkInfo thunk_info,
-                              CudnnBatchNormConfig&& config,
-                              const BufferAllocation::Slice& operand,
-                              const BufferAllocation::Slice& scale,
-                              const BufferAllocation::Slice& mean,
-                              const BufferAllocation::Slice& inv_stddev,
-                              const BufferAllocation::Slice& grad_output,
-                              const BufferAllocation::Slice& output_grad_data,
-                              const BufferAllocation::Slice& output_grad_scale,
-                              const BufferAllocation::Slice& output_grad_offset,
-                              const BufferAllocation::Slice& output_tuple);
+  CudnnBatchNormBackwardThunk(
+      ThunkInfo thunk_info, CudnnBatchNormConfig&& config,
+      const BufferAllocation::Slice& operand,
+      const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
+      const BufferAllocation::Slice& inv_stddev,
+      const BufferAllocation::Slice& grad_output,
+      const BufferAllocation::Slice& output_grad_data,
+      const BufferAllocation::Slice& output_grad_scale,
+      const BufferAllocation::Slice& output_grad_offset);
 
   CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
   CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =
@@ -133,7 +129,6 @@ class CudnnBatchNormBackwardThunk : public Thunk {
   BufferAllocation::Slice output_grad_data_;
   BufferAllocation::Slice output_grad_scale_;
   BufferAllocation::Slice output_grad_offset_;
-  BufferAllocation::Slice output_tuple_;
 };
 
 }  // namespace gpu
--- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
@@ -258,8 +258,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
         /*offset=*/GetAllocationSlice(*custom_call->operand(2)),
         /*output_data=*/output_data,
         /*output_mean=*/output_mean,
-        /*output_inv_stddev=*/output_inv_stddev,
-        /*output_tuple=*/GetAllocationSlice(*custom_call)));
+        /*output_inv_stddev=*/output_inv_stddev));
     return Status::OK();
   }
 
@@ -295,8 +294,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
         /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
         /*output_grad_data=*/output_grad_data,
         /*output_grad_scale=*/output_grad_scale,
-        /*output_grad_offset=*/output_grad_offset,
-        /*output_tuple=*/GetAllocationSlice(*custom_call)));
+        /*output_grad_offset=*/output_grad_offset));
     return Status::OK();
   }
 
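Why the tuple buffer is dead: consumers of these custom-calls reach the three results through get-tuple-element instructions, and buffer assignment gives each tuple element its own allocation slice, so a GTE resolves to an element slice without ever reading the tuple's pointer array. A minimal sketch of that resolution, assuming the standard BufferAssignment API; SliceForGte is a hypothetical helper for illustration, not part of this change:

  #include "tensorflow/compiler/xla/service/buffer_assignment.h"
  #include "tensorflow/compiler/xla/service/hlo_instruction.h"

  // Hypothetical: the slice a GetTupleElement occupies is the slice
  // assigned to element `tuple_index()` of its operand, so emitters read
  // the element buffer directly and never touch the tuple pointer array.
  xla::BufferAllocation::Slice SliceForGte(
      const xla::BufferAssignment& assignment,
      const xla::HloInstruction* gte) {
    CHECK_EQ(gte->opcode(), xla::HloOpcode::kGetTupleElement);
    return assignment.GetUniqueSlice(gte->operand(0), {gte->tuple_index()})
        .ValueOrDie();
  }

This is also why the call sites above no longer pass GetAllocationSlice(*custom_call), the slice of the whole tuple, to either thunk.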