From 24974ed8bb25b26a269cec126437fbb8cb2e943a Mon Sep 17 00:00:00 2001
From: Deven Desai
Date: Tue, 2 Jul 2019 19:18:21 +0000
Subject: [PATCH] Adding ROCm support for the batch_matmul op

---
 .../core/kernels/batch_matmul_op_complex.cc |  4 ++--
 .../core/kernels/batch_matmul_op_impl.h     | 24 ++++++++++----------
 .../core/kernels/batch_matmul_op_real.cc    |  4 ++--
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 16913986d21..2cf163be0d4 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -20,9 +20,9 @@ namespace tensorflow {
 TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 1798b272e0c..84f7571d6a4 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -42,9 +42,9 @@ limitations under the License.
 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
 #endif
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/platform/stream_executor.h"
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace tensorflow {
 
@@ -248,22 +248,22 @@ struct LaunchBatchMatMul {
   }
 };
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace {
 template <typename T>
-se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
-  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
+  se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
   se::DeviceMemory<T> typed(wrapped);
   return typed;
 }
 
-class CublasScratchAllocator : public se::ScratchAllocator {
+class BlasScratchAllocator : public se::ScratchAllocator {
  public:
   using Stream = se::Stream;
   using DeviceMemoryBytes = se::DeviceMemory<uint8>;
 
-  CublasScratchAllocator(OpKernelContext* context) : context_(context) {}
+  BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
 
   int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }
 
@@ -357,7 +357,7 @@ struct LaunchBatchMatMul {
 
     typedef Scalar Coefficient;
 
-    // Cublas does
+    // Blas does
     // C = A x B
     // where A, B and C are assumed to be in column major.
     // We want the output to be in row-major, so we can compute
@@ -406,7 +406,7 @@ struct LaunchBatchMatMul {
         }
       }
     } else {
-      CublasScratchAllocator scratch_allocator(context);
+      BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
           stream
               ->ThenBlasGemmBatchedWithScratch(
@@ -493,7 +493,7 @@ struct LaunchBatchMatMul {
 
     typedef float Coefficient;
 
-    // Cublas does
+    // Blas does
     // C = A x B
     // where A, B and C are assumed to be in column major.
     // We want the output to be in row-major, so we can compute
@@ -517,7 +517,7 @@ struct LaunchBatchMatMul {
                                         ", k=", k));
       }
     } else {
-      CublasScratchAllocator scratch_allocator(context);
+      BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
           stream
               ->ThenBlasGemmBatchedWithScratch(
@@ -537,7 +537,7 @@ struct LaunchBatchMatMul {
   }
 };
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #ifdef TENSORFLOW_USE_SYCL
 template <typename Scalar>
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 3870c227a23..12c1f48a3c8 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -27,11 +27,11 @@ TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #ifdef TENSORFLOW_USE_SYCL
 TF_CALL_float(REGISTER_BATCH_MATMUL_SYCL);