From 9c825d32c9423980e1b263a50360e03e833b69a6 Mon Sep 17 00:00:00 2001
From: Jinze Bai <baijinze1994@163.com>
Date: Sat, 21 Oct 2017 07:12:31 +0800
Subject: [PATCH] Merge two GPU kernel launching to one in DiagOp. (#13859)

---
 tensorflow/core/kernels/diag_op_gpu.cu.cc | 49 +++++++++--------------
 1 file changed, 19 insertions(+), 30 deletions(-)

diff --git a/tensorflow/core/kernels/diag_op_gpu.cu.cc b/tensorflow/core/kernels/diag_op_gpu.cu.cc
index 9878f347d2a..684f00ea61d 100644
--- a/tensorflow/core/kernels/diag_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/diag_op_gpu.cu.cc
@@ -33,15 +33,12 @@ __global__ void DiagCudaKernel(const int num_threads,
                                const T* in,
                                T* out) {
   CUDA_1D_KERNEL_LOOP(index, num_threads) {
-    out[(1 + size) * index] = in[index];
-  }
-}
-
-template <typename T>
-__global__ void ZeroCudaKernel(const int num_threads,
-                               T* out) {
-  CUDA_1D_KERNEL_LOOP(index, num_threads) {
-    out[index] = T(0);
+    // Fill the diagonal elements or set to zero in other place. 
+    if (index % (1 + size) == 0) {
+      out[index] = in[index / (1 + size)];
+    } else {
+      out[index] = T(0);
+    }
   }
 }
 
@@ -50,39 +47,30 @@ struct DiagFunctor<GPUDevice, T> {
   EIGEN_ALWAYS_INLINE Status
   operator() (OpKernelContext* context, const int64 size,
               const T* in, T* out) {
-    // CudaLaunchConfig uses an int for virtual_thread_count,
-    // so this may overflow in extreme cases.
-    if (size && (size * size / size) != size) {
-      return errors::Internal(
-          "DiagOp got input size too large.");
-    }
-
     // Empty tensor couldn't launch the kernel.
     if (size == 0) {
       return Status::OK();
     }
-    const GPUDevice& device = context->eigen_device<GPUDevice>();
 
-    // Set output memory with zero elements.
-    CudaLaunchConfig zero_config = GetCudaLaunchConfig(size*size, device);
-    ZeroCudaKernel<<<zero_config.block_count,
-                     zero_config.thread_per_block,
-                     0, device.stream()>>>(
-        zero_config.virtual_thread_count, out);
-    auto err = cudaGetLastError();
-    if (err != cudaSuccess) {
+    // CudaLaunchConfig uses an int for virtual_thread_count,
+    // so this may overflow for `size*size` in extreme cases,
+    // here is checking the multiplication overflow for integer.
+    if (size && (int(size * size) / size) != size) {
       return errors::Internal(
-          "Could not launch DiagOp kernel: ",
-          cudaGetErrorString(err), ".");
+          "DiagOp got input size too large.");
     }
+    int virtual_thread_count = int(size * size);
 
-    // Fill the diagonal elements
-    CudaLaunchConfig diag_config = GetCudaLaunchConfig(size, device);
+    // Launch the GPU kernel.
+    const GPUDevice& device = context->eigen_device<GPUDevice>();
+    CudaLaunchConfig diag_config = GetCudaLaunchConfig(
+        virtual_thread_count, device);
     DiagCudaKernel<<<diag_config.block_count,
                      diag_config.thread_per_block,
                      0, device.stream()>>>(
         diag_config.virtual_thread_count, size, in, out);
-    err = cudaGetLastError();
+
+    auto err = cudaGetLastError();
     if (err != cudaSuccess) {
       return errors::Internal(
           "Could not launch DiagOp kernel: ",
@@ -127,6 +115,7 @@ struct DiagPartFunctor<GPUDevice, T> {
                      diag_config.thread_per_block,
                      0, device.stream()>>>(
         diag_config.virtual_thread_count, size, in, out);
+
     auto err = cudaGetLastError();
     if (err != cudaSuccess) {
       return errors::Internal(