From b0597e4175c97741b090b19b1548da0b2d2aa914 Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Wed, 29 May 2019 15:41:21 +0000 Subject: [PATCH] Adding ROCm support for the batch_norm op --- tensorflow/core/kernels/batch_norm_op.cc | 8 ++++---- tensorflow/core/kernels/batch_norm_op_gpu.cu.cc | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc index 609ddd68caf..4a03abbba49 100644 --- a/tensorflow/core/kernels/batch_norm_op.cc +++ b/tensorflow/core/kernels/batch_norm_op.cc @@ -175,7 +175,7 @@ TF_CALL_float(REGISTER_KERNEL); TF_CALL_double(REGISTER_KERNEL); #undef REGISTER_KERNEL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -206,7 +206,7 @@ TF_CALL_half(REGISTER_GPU_KERNEL); TF_CALL_float(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if TENSORFLOW_USE_SYCL #define REGISTER_KERNEL(T) \ @@ -231,7 +231,7 @@ TF_CALL_float(REGISTER_KERNEL); TF_CALL_double(REGISTER_KERNEL); #undef REGISTER_KERNEL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -265,7 +265,7 @@ TF_CALL_half(REGISTER_GPU_KERNEL); TF_CALL_float(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if TENSORFLOW_USE_SYCL #define REGISTER_KERNEL(T) \ diff --git a/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc b/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc index 1c4184cf6dd..e57cb16a620 100644 --- a/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc +++ b/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -31,4 +31,4 @@ template struct functor::BatchNormGrad; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM