From 83a0c0e427c864890e23ef7af70467baf3e51213 Mon Sep 17 00:00:00 2001
From: Jeff Daily <jeff.daily@amd.com>
Date: Mon, 3 Jun 2019 21:09:03 +0000
Subject: [PATCH] Add ROCm support for remaining cwise ops and tests

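cwise_ops.h -- PacketAccess is forced to false under TENSORFLOW_USE_ROCM for
the functor traits touched by this patch.
On ROCm, the complex64 / complex128 GPU kernels for Div, RealDiv, DivNoNan and
MulNoNan are not registered yet (pending a compiler fix), and Zeta remains
CUDA-only.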
---
 tensorflow/core/kernels/cwise_op_clip.cc      |  2 +-
 .../core/kernels/cwise_op_clip_gpu.cu.cc      | 22 +++---
 tensorflow/core/kernels/cwise_op_div.cc       | 17 +++++
 .../core/kernels/cwise_op_gpu_div.cu.cc       |  9 ++-
 .../core/kernels/cwise_op_gpu_zeta.cu.cc      |  4 +-
 tensorflow/core/kernels/cwise_op_mul_1.cc     |  9 ++-
 tensorflow/core/kernels/cwise_op_zeta.cc      |  2 +
 tensorflow/core/kernels/cwise_ops.h           | 49 ++++++++++++-
 tensorflow/core/kernels/cwise_ops_test.cc     | 68 +++++++++----------
 9 files changed, 130 insertions(+), 52 deletions(-)

diff --git a/tensorflow/core/kernels/cwise_op_clip.cc b/tensorflow/core/kernels/cwise_op_clip.cc
index 49b90e855be..c0c71c5f638 100644
--- a/tensorflow/core/kernels/cwise_op_clip.cc
+++ b/tensorflow/core/kernels/cwise_op_clip.cc
@@ -181,7 +181,7 @@ REGISTER_CPU_KERNEL(uint8);
 REGISTER_CPU_KERNEL(uint16);
 #undef REGISTER_CPU_KERNEL
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU_KERNEL(type)                                       \
   REGISTER_KERNEL_BUILDER(                                              \
diff --git a/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc b/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
index 8da2dfb1d5a..55d3033483f 100644
--- a/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
@@ -26,7 +26,7 @@ namespace tensorflow {
 template <typename T>
 __global__ void UnaryClipCustomKernel(const int32 size_in, const T *in0,
                                       const T *in1, const T *in2, T *out) {
-  CUDA_1D_KERNEL_LOOP(i, size_in) {
+  GPU_1D_KERNEL_LOOP(i, size_in) {
     T value = in2[0] < in0[i] ? in2[0] : in0[i];
     out[i] = value < in1[0] ? in1[0] : value;
   }
@@ -36,7 +36,7 @@ template <typename T>
 __global__ void BinaryRightClipCustomKernel(const int32 size_in, const T *in0,
                                             const T *in1, const T *in2,
                                             T *out) {
-  CUDA_1D_KERNEL_LOOP(i, size_in) {
+  GPU_1D_KERNEL_LOOP(i, size_in) {
     T value = in2[i] < in0[i] ? in2[i] : in0[i];
     out[i] = value < in1[0] ? in1[0] : value;
   }
@@ -45,7 +45,7 @@ __global__ void BinaryRightClipCustomKernel(const int32 size_in, const T *in0,
 template <typename T>
 __global__ void BinaryLeftClipCustomKernel(const int32 size_in, const T *in0,
                                            const T *in1, const T *in2, T *out) {
-  CUDA_1D_KERNEL_LOOP(i, size_in) {
+  GPU_1D_KERNEL_LOOP(i, size_in) {
     T value = in2[0] < in0[i] ? in2[0] : in0[i];
     out[i] = value < in1[i] ? in1[i] : value;
   }
@@ -60,9 +60,9 @@ struct UnaryClipOp<GPUDevice, T> {
                   typename TTypes<T>::ConstFlat &in1_flat,
                   typename TTypes<T>::ConstFlat &in2_flat,
                   typename TTypes<T>::Flat &out_flat) const {
-    GpuLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);
+    GpuLaunchConfig config = GetGpuLaunchConfig(in0_flat.size(), d);
 
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         UnaryClipCustomKernel<T>, config.block_count, config.thread_per_block,
         0, d.stream(), in0_flat.size(), in0_flat.data(), in1_flat.data(),
         in2_flat.data(), out_flat.data()));
@@ -76,9 +76,9 @@ struct BinaryRightClipOp<GPUDevice, T> {
                   typename TTypes<T>::ConstFlat &in1_flat,
                   typename TTypes<T>::ConstFlat &in2_flat,
                   typename TTypes<T>::Flat &out_flat) const {
-    GpuLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);
+    GpuLaunchConfig config = GetGpuLaunchConfig(in0_flat.size(), d);
 
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         BinaryRightClipCustomKernel<T>, config.block_count,
         config.thread_per_block, 0, d.stream(), in0_flat.size(),
         in0_flat.data(), in1_flat.data(), in2_flat.data(), out_flat.data()));
@@ -92,9 +92,9 @@ struct BinaryLeftClipOp<GPUDevice, T> {
                   typename TTypes<T>::ConstFlat &in1_flat,
                   typename TTypes<T>::ConstFlat &in2_flat,
                   typename TTypes<T>::Flat &out_flat) const {
-    GpuLaunchConfig config = GetCudaLaunchConfig(in0_flat.size(), d);
+    GpuLaunchConfig config = GetGpuLaunchConfig(in0_flat.size(), d);
 
-    TF_CHECK_OK(CudaLaunchKernel(
+    TF_CHECK_OK(GpuLaunchKernel(
         BinaryLeftClipCustomKernel<T>, config.block_count,
         config.thread_per_block, 0, d.stream(), in0_flat.size(),
         in0_flat.data(), in1_flat.data(), in2_flat.data(), out_flat.data()));
@@ -131,4 +131,4 @@ INSTANTIATE_GPU(uint16);
 }  // namespace functor
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index 169041a4ad4..2917bca121e 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -27,15 +27,32 @@ REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
 REGISTER5(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
           double, complex64, complex128);
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// ROCM TODO: re-enable complex64 / complex128 after compiler fix
 #if GOOGLE_CUDA
 REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
           uint16, int16, int64, complex64, complex128);
+#elif TENSORFLOW_USE_ROCM
+REGISTER7(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
+          uint16, int16, int64);
+#endif
 REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
           int64);
+// ROCM TODO: re-enable complex64 / complex128 after compiler fix
+#if GOOGLE_CUDA
 REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
           complex64, complex128);
+#elif TENSORFLOW_USE_ROCM
+REGISTER3(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double);
+#endif
+
+#if GOOGLE_CUDA
 REGISTER5(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
           double, complex64, complex128);
+#elif TENSORFLOW_USE_ROCM
+REGISTER3(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
+          double);
+#endif
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
index b50f4abb701..dc5e56867b5 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
 
@@ -21,8 +21,13 @@ namespace tensorflow {
 namespace functor {
 DEFINE_BINARY10(div, Eigen::half, float, double, uint8, uint16, int16, int32,
                 int64, complex64, complex128);
+#if GOOGLE_CUDA
 DEFINE_BINARY5(div_no_nan, Eigen::half, float, double, complex64, complex128);
+#elif TENSORFLOW_USE_ROCM
+// ROCM TODO: re-enable complex64 / complex128 for div_no_nan after compiler fix
+DEFINE_BINARY3(div_no_nan, Eigen::half, float, double);
+#endif
 }  // namespace functor
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc
index 8f64a904473..41499ea096f 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_zeta.cu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
 
@@ -24,4 +24,4 @@ DEFINE_BINARY2(polygamma, float, double);
 }  // namespace functor
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/cwise_op_mul_1.cc b/tensorflow/core/kernels/cwise_op_mul_1.cc
index 8ec618cf88b..2765005615a 100644
--- a/tensorflow/core/kernels/cwise_op_mul_1.cc
+++ b/tensorflow/core/kernels/cwise_op_mul_1.cc
@@ -29,7 +29,7 @@ REGISTER5(BinaryOp, CPU, "MulNoNan", functor::mul_no_nan, Eigen::half, float,
 REGISTER(BinaryOp, CPU, "Mul", functor::mul, int32);
 #endif  // __ANDROID_TYPES_SLIM__
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER4(BinaryOp, GPU, "Mul", functor::mul, Eigen::half, float, double,
           uint8);
 // A special GPU kernel for int32.
@@ -42,8 +42,15 @@ REGISTER_KERNEL_BUILDER(Name("Mul")
                             .HostMemory("z")
                             .TypeConstraint<int32>("T"),
                         BinaryOp<CPUDevice, functor::mul<int32>>);
+#endif
+
+#if GOOGLE_CUDA
 REGISTER5(BinaryOp, GPU, "MulNoNan", functor::mul_no_nan, Eigen::half, float,
           double, complex64, complex128);
+#elif TENSORFLOW_USE_ROCM
+// ROCM TODO: re-enable complex64 / complex128 after compiler fix
+REGISTER3(BinaryOp, GPU, "MulNoNan", functor::mul_no_nan, Eigen::half, float,
+          double);
 #endif
 
 #ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/cwise_op_zeta.cc b/tensorflow/core/kernels/cwise_op_zeta.cc
index dc064eec5f7..c3a3a38ae40 100644
--- a/tensorflow/core/kernels/cwise_op_zeta.cc
+++ b/tensorflow/core/kernels/cwise_op_zeta.cc
@@ -21,6 +21,8 @@ REGISTER2(BinaryOp, CPU, "Polygamma", functor::polygamma, float, double);
 
 #if GOOGLE_CUDA
 REGISTER2(BinaryOp, GPU, "Zeta", functor::zeta, float, double);
+#endif
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER2(BinaryOp, GPU, "Polygamma", functor::polygamma, float, double);
 #endif
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 48297bc2848..2849e39c836 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -176,7 +176,11 @@ template <typename T>
 struct functor_traits<div_no_nan_op<T>> {
   enum {
     Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = true,
+#endif // TENSORFLOW_USE_ROCM
   };
 };
 
@@ -189,7 +193,11 @@ template <typename T>
 struct functor_traits<mul_no_nan_op<T>> {
   enum {
     Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = true,
+#endif // TENSORFLOW_USE_ROCM
   };
 };
 
@@ -227,7 +235,11 @@ template <typename Tout, typename Tin, typename Binary>
 struct functor_traits<scalar_left<Tout, Tin, Binary>> {
   enum {
     Cost = functor_traits<Binary>::Cost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = functor_traits<Binary>::PacketAccess,
+#endif // TENSORFLOW_USE_ROCM
   };
 };
 
@@ -257,7 +269,11 @@ template <typename Tout, typename Tin, typename Binary>
 struct functor_traits<scalar_right<Tout, Tin, Binary>> {
   enum {
     Cost = functor_traits<Binary>::Cost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = functor_traits<Binary>::PacketAccess,
+#endif // TENSORFLOW_USE_ROCM
   };
 };
 
@@ -350,8 +366,13 @@ struct google_floor_div {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                            const T& y) const {
     if ((x < T(0)) != (y < T(0))) {
+#if defined(__HIP_DEVICE_COMPILE__)
+      T abs_x = (x < T(0)) ? -x : x;
+      T abs_y = (y < T(0)) ? -y : y;
+#else
       T abs_x = std::abs(x);
       T abs_y = std::abs(y);
+#endif
       return -(abs_x + abs_y - 1) / abs_y;
     } else {
       return x / y;
@@ -392,7 +413,11 @@ struct functor_traits<google_floor_div<Scalar>> {
     Cost = 2 * Eigen::internal::scalar_div_cost<
                    Scalar, packet_traits<Scalar>::HasDiv>::value +
            NumTraits<Scalar>::AddCost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = packet_traits<Scalar>::HasDiv
+#endif
   };
 };
 
@@ -415,8 +440,12 @@ struct functor_traits<google_floor_div_real<Scalar>> {
     Cost = 2 * Eigen::internal::scalar_div_cost<
                    Scalar, packet_traits<Scalar>::HasDiv>::value +
            2 * NumTraits<Scalar>::AddCost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess =
         packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor
+#endif
   };
 };
 
@@ -513,7 +542,11 @@ struct functor_traits<scalar_round_op_google<Scalar>> {
   enum {
     Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0
                                                : 4 * NumTraits<Scalar>::AddCost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = Eigen::NumTraits<Scalar>::IsInteger
+#endif
   };
 };
 
@@ -551,7 +584,11 @@ template <typename Scalar, bool IsInteger>
 struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> {
   enum {
     Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor
+#endif
   };
 };
 
@@ -604,7 +641,11 @@ struct functor_traits<xlogy_op<Scalar>> {
   enum {
     Cost = functor_traits<scalar_log_op<Scalar>>::Cost +
            Eigen::NumTraits<Scalar>::MulCost,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess
+#endif
   };
 };
 
@@ -635,7 +676,11 @@ struct functor_traits<xdivy_op<Scalar>> {
         Eigen::NumTraits<Scalar>::AddCost +
         Eigen::internal::scalar_div_cost<Scalar,
                                          packet_traits<Scalar>::HasDiv>::value,
+#if TENSORFLOW_USE_ROCM
+    PacketAccess = false,
+#else
     PacketAccess = packet_traits<Scalar>::HasDiv
+#endif
   };
 };
 
@@ -855,7 +900,7 @@ struct scalar_rint_op {
   EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
   operator()(const Scalar& a) const {
-#if defined(__CUDACC__)
+#if defined(__CUDACC__) || defined(__HIPCC__)
     return ::rint(a);
 #elif defined(__ANDROID__)
     return rint(a);
@@ -985,6 +1030,8 @@ struct scalar_atan2_op {
   operator()(const Scalar& y, const Scalar& x) const {
 #if GOOGLE_CUDA
     return std::atan2(y, x);
+#elif TENSORFLOW_USE_ROCM
+    return ::atan2(y, x);
 #else
     return std::atan2(y, x);
 #endif
diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc
index acf7cc28993..73ba6d5968b 100644
--- a/tensorflow/core/kernels/cwise_ops_test.cc
+++ b/tensorflow/core/kernels/cwise_ops_test.cc
@@ -53,38 +53,38 @@ int ColsFromArg(int arg) { return (arg % kRows); }
   BENCHMARK(BM_##DEVICE##_##FUNC##_##TYPE)->Range(4 << 10, 1 << 20);
 
 BM_UNARY(cpu, Floor, float, DT_FLOAT);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_UNARY(gpu, Floor, float, DT_FLOAT);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 BM_UNARY(sycl, Floor, float, DT_FLOAT);
 #endif  // TENSORFLOW_USE_SYCL
 
 BM_UNARY(cpu, Floor, double, DT_DOUBLE);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_UNARY(gpu, Floor, double, DT_DOUBLE);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 BM_UNARY(sycl, Floor, double, DT_DOUBLE);
 #endif  // TENSORFLOW_USE_SYCL
 
 BM_UNARY(cpu, Conj, std::complex<float>, DT_COMPLEX64);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_UNARY(gpu, Conj, std::complex<float>, DT_COMPLEX64);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_UNARY(cpu, Conj, std::complex<double>, DT_COMPLEX128);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_UNARY(gpu, Conj, std::complex<double>, DT_COMPLEX128);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 BM_UNARY(cpu, Rint, double, DT_DOUBLE);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_UNARY(gpu, Rint, double, DT_DOUBLE);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_UNARY(cpu, Rint, float, DT_FLOAT);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_UNARY(gpu, Rint, float, DT_FLOAT);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 // data func scalar.
 Graph* BinaryScalar(int num, const string& func) {
@@ -113,17 +113,17 @@ Graph* BinaryScalar(int num, const string& func) {
       ->Arg(1048576);
 
 BM_BINARY_SCALAR(cpu, Less);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BINARY_SCALAR(gpu, Less);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 BM_BINARY_SCALAR(sycl, Less);
 #endif  // TENSORFLOW_USE_SYCL
 
 BM_BINARY_SCALAR(cpu, Add);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BINARY_SCALAR(gpu, Add);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 BM_BINARY_SCALAR(sycl, Add);
 #endif  // TENSORFLOW_USE_SYCL
@@ -173,13 +173,13 @@ Graph* BiasAdd(int rows, int cols, DataType type) {
 
 using Eigen::half;
 BM_BIAS_ADD_ALL(cpu, float, DT_FLOAT);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BIAS_ADD_ALL(gpu, float, DT_FLOAT);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BIAS_ADD_ALL(cpu, half, DT_HALF);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BIAS_ADD_ALL(gpu, half, DT_HALF);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef BM_BIAS_ADD_ALL
 #undef BM_BIAS_ADD
 
@@ -227,18 +227,18 @@ Graph* BiasAddGrad(int rows, int cols, int channels, DataType type,
   BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 4096, 4096, 1);
 
 using Eigen::half;
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, float, DT_FLOAT);
 BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, half, DT_HALF);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, float, DT_FLOAT);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, float, DT_FLOAT);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, half, DT_HALF);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, half, DT_HALF);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef BM_BIAS_ADD_GRAD_ALL
 #undef BM_BIAS_ADD_GRAD
 
@@ -285,9 +285,9 @@ Graph* BcastAdd(int rows, int cols, int dim) {
   BM_BCAST_ADD_ROW(DEVICE, 2048, 512); \
   BM_BCAST_ADD_ROW(DEVICE, 4096, 512);
 BM_BCAST_ADD_ROW_ALL(cpu);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BCAST_ADD_ROW_ALL(gpu);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 BM_BCAST_ADD_ROW_ALL(sycl);
 #endif  // TENSORFLOW_USE_SYCL
@@ -312,9 +312,9 @@ BM_BCAST_ADD_ROW_ALL(sycl);
   BM_BCAST_ADD_COL(DEVICE, 2048, 512); \
   BM_BCAST_ADD_COL(DEVICE, 4096, 512);
 BM_BCAST_ADD_COL_ALL(cpu);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BCAST_ADD_COL_ALL(gpu);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 BM_BCAST_ADD_COL_ALL(sycl);
 #endif  // TENSORFLOW_USE_SYCL
@@ -340,9 +340,9 @@ BM_BCAST_ADD_COL_ALL(sycl);
   BM_BCAST_ADD_CROSS_RC(DEVICE, 2048, 512); \
   BM_BCAST_ADD_CROSS_RC(DEVICE, 4096, 512);
 BM_BCAST_ADD_CROSS_RC_ALL(cpu);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BCAST_ADD_CROSS_RC_ALL(gpu);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 BM_BCAST_ADD_CROSS_RC_ALL(sycl);
 #endif  // TENSORFLOW_USE_SYCL
@@ -368,9 +368,9 @@ BM_BCAST_ADD_CROSS_RC_ALL(sycl);
   BM_BCAST_ADD_CROSS_CR(DEVICE, 2048, 512); \
   BM_BCAST_ADD_CROSS_CR(DEVICE, 4096, 512);
 BM_BCAST_ADD_CROSS_CR_ALL(cpu);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 BM_BCAST_ADD_CROSS_CR_ALL(gpu);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #ifdef TENSORFLOW_USE_SYCL
 BM_BCAST_ADD_CROSS_CR_ALL(sycl);
 #endif  // TENSORFLOW_USE_SYCL