From 124daa37281e654522c212f8bae0c1743e7437fa Mon Sep 17 00:00:00 2001 From: Deven Desai <deven.desai.amd@gmail.com> Date: Mon, 10 Jun 2019 14:51:44 +0000 Subject: [PATCH] Adding ROCm support for scatter ops --- tensorflow/core/kernels/scatter_functor.cc | 4 +-- .../core/kernels/scatter_functor_gpu.cu.cc | 4 +-- .../core/kernels/scatter_functor_gpu.cu.h | 28 +++++++++---------- tensorflow/core/kernels/scatter_nd_op.cc | 12 ++++---- .../core/kernels/scatter_nd_op_gpu.cu.cc | 18 ++++++------ tensorflow/core/kernels/scatter_op.cc | 4 +-- tensorflow/core/kernels/scatter_op_gpu.cu.cc | 4 +-- 7 files changed, 37 insertions(+), 37 deletions(-) diff --git a/tensorflow/core/kernels/scatter_functor.cc b/tensorflow/core/kernels/scatter_functor.cc index cf5408123fb..f17d8759d20 100644 --- a/tensorflow/core/kernels/scatter_functor.cc +++ b/tensorflow/core/kernels/scatter_functor.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/scatter_functor.h" #include "tensorflow/core/framework/register_types.h" @@ -68,4 +68,4 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS); #include "tensorflow/core/kernels/scatter_functor.h" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc index bdc878594a3..7bfd0051de9 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -54,4 +54,4 @@ DEFINE_GPU_SPECS_OP(bool, int64, scatter_op::UpdateOp::ASSIGN); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h index 6c195e59e20..11baad0d585 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_ #define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -41,32 +41,32 @@ struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ASSIGN> { template <typename T> struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ADD> { - __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); } + __device__ void operator()(T* dest, T src) const { GpuAtomicAdd(dest, src); } }; template <typename T> struct ScatterOpKernelBody<T, scatter_op::UpdateOp::SUB> { - __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); } + __device__ void operator()(T* dest, T src) const { GpuAtomicSub(dest, src); } }; template <typename T> struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MUL> { - __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); } + __device__ void operator()(T* dest, T src) const { GpuAtomicMul(dest, src); } }; template <typename T> struct ScatterOpKernelBody<T, scatter_op::UpdateOp::DIV> { - __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); } + __device__ void operator()(T* dest, T src) const { GpuAtomicDiv(dest, src); } }; template <typename T> struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MIN> { - __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); } + __device__ void operator()(T* dest, T src) const { GpuAtomicMin(dest, src); } }; template <typename T> struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> { - __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); } + __device__ void operator()(T* dest, T src) const { GpuAtomicMax(dest, src); } }; template <typename T, typename Index, scatter_op::UpdateOp op> @@ -76,7 +76,7 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates, Index indices_size) { Index update_block = updates_size / indices_size; ScatterOpKernelBody<T, op> body; - CUDA_1D_KERNEL_LOOP(i, updates_size) { + GPU_1D_KERNEL_LOOP(i, updates_size) { int indices_i = i / update_block; int updates_i = i; int param_first_index = indices[indices_i]; @@ -97,7 +97,7 @@ __global__ void ScatterScalarOpCustomKernel(T* params, const T* update, Index synthesized_updates_size) { Index update_block = synthesized_updates_size / indices_size; ScatterOpKernelBody<T, op> body; - CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) { + GPU_1D_KERNEL_LOOP(i, synthesized_updates_size) { int indices_i = i / update_block; int param_first_index = indices[indices_i]; const T update_val = *update; @@ -126,8 +126,8 @@ struct ScatterFunctor<GPUDevice, T, Index, op> { const Index first_dim_size = params.dimension(0); const Index indices_size = indices.size(); const Index updates_size = updates.size(); - GpuLaunchConfig config = GetCudaLaunchConfig(updates_size, d); - TF_CHECK_OK(CudaLaunchKernel( + GpuLaunchConfig config = GetGpuLaunchConfig(updates_size, d); + TF_CHECK_OK(GpuLaunchKernel( scatter_op_gpu::ScatterOpCustomKernel<T, Index, op>, config.block_count, config.thread_per_block, 0, d.stream(), params.data(), updates.data(), indices.data(), first_dim_size, updates_size, indices_size)); @@ -147,8 +147,8 @@ struct ScatterScalarFunctor<GPUDevice, T, Index, op> { const Index first_dim_size = params.dimension(0); const Index indices_size = indices.size(); const Index synthesized_updates_size = indices_size * params.dimension(1); - GpuLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d); - TF_CHECK_OK(CudaLaunchKernel( + GpuLaunchConfig config = GetGpuLaunchConfig(synthesized_updates_size, d); + TF_CHECK_OK(GpuLaunchKernel( scatter_op_gpu::ScatterScalarOpCustomKernel<T, Index, op>, config.block_count, config.thread_per_block, 0, d.stream(), params.data(), update.data(), indices.data(), first_dim_size, @@ -160,6 +160,6 @@ struct ScatterScalarFunctor<GPUDevice, T, Index, op> { } // namespace functor } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_ diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index d307385e3a7..abf7cfde135 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -16,9 +16,9 @@ limitations under the License. // See docs in ../ops/state_ops.cc. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/scatter_nd_op.h" @@ -434,7 +434,7 @@ TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU); #undef REGISTER_SCATTER_ND_TENSOR_CPU // Registers GPU kernels. -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \ REGISTER_SCATTER_ND_ADD_SUB(type, GPU); @@ -509,7 +509,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU); #undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU #undef REGISTER_SCATTER_ND_TENSOR_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { // Check whether updates.shape = indices.shape[:batch_dim] + @@ -734,7 +734,7 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, } } // namespace functor -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ @@ -777,6 +777,6 @@ TF_CALL_complex128(DECLARE_GPU_SPECS); } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc index 9152e71acb2..c0032be2e03 100644 --- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -44,14 +44,14 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::ASSIGN> { template <typename T> struct LeftUpdate<T, scatter_nd_op::UpdateOp::ADD> { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) { - CudaAtomicAdd(out, val); + GpuAtomicAdd(out, val); } }; template <typename T> struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) { - CudaAtomicSub(out, val); + GpuAtomicSub(out, val); } }; @@ -63,8 +63,8 @@ struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()( std::complex<T>* out, const std::complex<T>& val) { T* ptr = reinterpret_cast<T*>(out); - CudaAtomicAdd(ptr, val.real()); - CudaAtomicAdd(ptr, val.imag()); + GpuAtomicAdd(ptr, val.real()); + GpuAtomicAdd(ptr, val.imag()); } }; @@ -86,7 +86,7 @@ __global__ void ScatterNdOpKernel( const Index slice_size) { auto update = LeftUpdate<T, op>(); - CUDA_1D_KERNEL_LOOP(index, num_indices) { + GPU_1D_KERNEL_LOOP(index, num_indices) { Index i = 0; bool out_of_bounds = false; #pragma unroll @@ -135,9 +135,9 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> { } } - GpuLaunchConfig config = GetCudaLaunchConfig(Toutput.size(), d); + GpuLaunchConfig config = GetGpuLaunchConfig(Toutput.size(), d); - TF_CHECK_OK(CudaLaunchKernel(ScatterNdOpKernel<T, Index, op, IXDIM>, + TF_CHECK_OK(GpuLaunchKernel(ScatterNdOpKernel<T, Index, op, IXDIM>, config.block_count, config.thread_per_block, 0, d.stream(), Tindices.data(), Tupdates.data(), Toutput.data(), output_shape_prefix, @@ -181,4 +181,4 @@ TF_CALL_complex128(DECLARE_GPU_SPECS); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index ee3c5833470..81deaad5c95 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -279,7 +279,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU); // Registers GPU kernels. -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ REGISTER_SCATTER_ARITHMETIC(type, GPU); @@ -291,7 +291,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU); TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Registers GPU kernels. #if TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc index d4defb85036..099604646fa 100644 --- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -53,4 +53,4 @@ DEFINE_GPU_SPECS(double); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM