diff --git a/tensorflow/core/kernels/diag_op_gpu.cu.cc b/tensorflow/core/kernels/diag_op_gpu.cu.cc index 9878f347d2a..684f00ea61d 100644 --- a/tensorflow/core/kernels/diag_op_gpu.cu.cc +++ b/tensorflow/core/kernels/diag_op_gpu.cu.cc @@ -33,15 +33,12 @@ __global__ void DiagCudaKernel(const int num_threads, const T* in, T* out) { CUDA_1D_KERNEL_LOOP(index, num_threads) { - out[(1 + size) * index] = in[index]; - } -} - -template -__global__ void ZeroCudaKernel(const int num_threads, - T* out) { - CUDA_1D_KERNEL_LOOP(index, num_threads) { - out[index] = T(0); + // Fill the diagonal elements or set to zero in other place. + if (index % (1 + size) == 0) { + out[index] = in[index / (1 + size)]; + } else { + out[index] = T(0); + } } } @@ -50,39 +47,30 @@ struct DiagFunctor { EIGEN_ALWAYS_INLINE Status operator() (OpKernelContext* context, const int64 size, const T* in, T* out) { - // CudaLaunchConfig uses an int for virtual_thread_count, - // so this may overflow in extreme cases. - if (size && (size * size / size) != size) { - return errors::Internal( - "DiagOp got input size too large."); - } - // Empty tensor couldn't launch the kernel. if (size == 0) { return Status::OK(); } - const GPUDevice& device = context->eigen_device(); - // Set output memory with zero elements. - CudaLaunchConfig zero_config = GetCudaLaunchConfig(size*size, device); - ZeroCudaKernel<<>>( - zero_config.virtual_thread_count, out); - auto err = cudaGetLastError(); - if (err != cudaSuccess) { + // CudaLaunchConfig uses an int for virtual_thread_count, + // so this may overflow for `size*size` in extreme cases, + // here is checking the multiplication overflow for integer. + if (size && (int(size * size) / size) != size) { return errors::Internal( - "Could not launch DiagOp kernel: ", - cudaGetErrorString(err), "."); + "DiagOp got input size too large."); } + int virtual_thread_count = int(size * size); - // Fill the diagonal elements - CudaLaunchConfig diag_config = GetCudaLaunchConfig(size, device); + // Launch the GPU kernel. + const GPUDevice& device = context->eigen_device(); + CudaLaunchConfig diag_config = GetCudaLaunchConfig( + virtual_thread_count, device); DiagCudaKernel<<>>( diag_config.virtual_thread_count, size, in, out); - err = cudaGetLastError(); + + auto err = cudaGetLastError(); if (err != cudaSuccess) { return errors::Internal( "Could not launch DiagOp kernel: ", @@ -127,6 +115,7 @@ struct DiagPartFunctor { diag_config.thread_per_block, 0, device.stream()>>>( diag_config.virtual_thread_count, size, in, out); + auto err = cudaGetLastError(); if (err != cudaSuccess) { return errors::Internal(