Merge two GPU kernel launches into one in DiagOp.

This commit is contained in:
Jinze Bai 2017-10-21 07:12:31 +08:00 committed by Vijay Vasudevan
parent c0ca50a477
commit 9c825d32c9

View File

@ -33,15 +33,12 @@ __global__ void DiagCudaKernel(const int num_threads,
const T* in,
T* out) {
CUDA_1D_KERNEL_LOOP(index, num_threads) {
out[(1 + size) * index] = in[index];
}
}
template <typename T>
__global__ void ZeroCudaKernel(const int num_threads,
T* out) {
CUDA_1D_KERNEL_LOOP(index, num_threads) {
out[index] = T(0);
// Fill the diagonal elements or set to zero in other place.
if (index % (1 + size) == 0) {
out[index] = in[index / (1 + size)];
} else {
out[index] = T(0);
}
}
}
@ -50,39 +47,30 @@ struct DiagFunctor<GPUDevice, T> {
EIGEN_ALWAYS_INLINE Status
operator() (OpKernelContext* context, const int64 size,
const T* in, T* out) {
// CudaLaunchConfig uses an int for virtual_thread_count,
// so this may overflow in extreme cases.
if (size && (size * size / size) != size) {
return errors::Internal(
"DiagOp got input size too large.");
}
// Empty tensor couldn't launch the kernel.
if (size == 0) {
return Status::OK();
}
const GPUDevice& device = context->eigen_device<GPUDevice>();
// Set output memory with zero elements.
CudaLaunchConfig zero_config = GetCudaLaunchConfig(size*size, device);
ZeroCudaKernel<<<zero_config.block_count,
zero_config.thread_per_block,
0, device.stream()>>>(
zero_config.virtual_thread_count, out);
auto err = cudaGetLastError();
if (err != cudaSuccess) {
// CudaLaunchConfig uses an int for virtual_thread_count,
// so this may overflow for `size*size` in extreme cases,
// here is checking the multiplication overflow for integer.
if (size && (int(size * size) / size) != size) {
return errors::Internal(
"Could not launch DiagOp kernel: ",
cudaGetErrorString(err), ".");
"DiagOp got input size too large.");
}
int virtual_thread_count = int(size * size);
// Fill the diagonal elements
CudaLaunchConfig diag_config = GetCudaLaunchConfig(size, device);
// Launch the GPU kernel.
const GPUDevice& device = context->eigen_device<GPUDevice>();
CudaLaunchConfig diag_config = GetCudaLaunchConfig(
virtual_thread_count, device);
DiagCudaKernel<<<diag_config.block_count,
diag_config.thread_per_block,
0, device.stream()>>>(
diag_config.virtual_thread_count, size, in, out);
err = cudaGetLastError();
auto err = cudaGetLastError();
if (err != cudaSuccess) {
return errors::Internal(
"Could not launch DiagOp kernel: ",
@ -127,6 +115,7 @@ struct DiagPartFunctor<GPUDevice, T> {
diag_config.thread_per_block,
0, device.stream()>>>(
diag_config.virtual_thread_count, size, in, out);
auto err = cudaGetLastError();
if (err != cudaSuccess) {
return errors::Internal(