Adding ROCm support for scatter ops
parent dac4bd7750
commit 124daa3728
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #include "tensorflow/core/kernels/scatter_functor.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -68,4 +68,4 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);

 #include "tensorflow/core/kernels/scatter_functor.h"

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
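
Every file in this commit gets the same treatment of its preprocessor guards. As a hedged illustration of the pattern (the macro names are the real TensorFlow ones; the comment wording is mine): GOOGLE_CUDA is defined when building with --config=cuda and TENSORFLOW_USE_ROCM with --config=rocm, so GPU-only translation units now compile for either backend:

// Guard pattern applied throughout this commit: the body is compiled when
// building for either GPU backend, and skipped entirely in CPU-only builds.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// ... kernels, functor specializations, and kernel registrations ...
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
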
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #define EIGEN_USE_GPU

@@ -54,4 +54,4 @@ DEFINE_GPU_SPECS_OP(bool, int64, scatter_op::UpdateOp::ASSIGN);

 }  // namespace tensorflow

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
 #define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #define EIGEN_USE_GPU

@@ -41,32 +41,32 @@ struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ASSIGN> {

 template <typename T>
 struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ADD> {
-  __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); }
+  __device__ void operator()(T* dest, T src) const { GpuAtomicAdd(dest, src); }
 };

 template <typename T>
 struct ScatterOpKernelBody<T, scatter_op::UpdateOp::SUB> {
-  __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); }
+  __device__ void operator()(T* dest, T src) const { GpuAtomicSub(dest, src); }
 };

 template <typename T>
 struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MUL> {
-  __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); }
+  __device__ void operator()(T* dest, T src) const { GpuAtomicMul(dest, src); }
 };

 template <typename T>
 struct ScatterOpKernelBody<T, scatter_op::UpdateOp::DIV> {
-  __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); }
+  __device__ void operator()(T* dest, T src) const { GpuAtomicDiv(dest, src); }
 };

 template <typename T>
 struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MIN> {
-  __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); }
+  __device__ void operator()(T* dest, T src) const { GpuAtomicMin(dest, src); }
 };

 template <typename T>
 struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> {
-  __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); }
+  __device__ void operator()(T* dest, T src) const { GpuAtomicMax(dest, src); }
 };

 template <typename T, typename Index, scatter_op::UpdateOp op>
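
The only change in these functor bodies is the rename from CudaAtomic* to the backend-neutral GpuAtomic* wrappers. A minimal sketch of the wrapper idea, assuming a simplified stand-in rather than TensorFlow's actual definitions (the real ones live in the GPU kernel-helper headers and also cover types that need software emulation):

// Sketch only: simplified stand-in for TensorFlow's GpuAtomicAdd wrapper.
// Both nvcc and hipcc expose atomicAdd with this shape, so a single wrapper
// body serves CUDA and ROCm; ops with no hardware intrinsic (e.g. Mul, Div)
// are emulated with compare-and-swap loops in the real helpers.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
__device__ inline T GpuAtomicAdd(T* dest, T value) {
  return atomicAdd(dest, value);  // vendor intrinsic, same spelling on both
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
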
@@ -76,7 +76,7 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates,
                                       Index indices_size) {
   Index update_block = updates_size / indices_size;
   ScatterOpKernelBody<T, op> body;
-  CUDA_1D_KERNEL_LOOP(i, updates_size) {
+  GPU_1D_KERNEL_LOOP(i, updates_size) {
     int indices_i = i / update_block;
     int updates_i = i;
     int param_first_index = indices[indices_i];
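
CUDA_1D_KERNEL_LOOP and its renamed form GPU_1D_KERNEL_LOOP are grid-stride loops; to a close approximation (a sketch of the expansion, not the exact TensorFlow macro):

// Approximate expansion of GPU_1D_KERNEL_LOOP(i, n): a grid-stride loop, so
// a launch with any grid size covers all n elements, each thread striding
// forward by the total number of launched threads.
#define GPU_1D_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
       i += blockDim.x * gridDim.x)
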
@@ -97,7 +97,7 @@ __global__ void ScatterScalarOpCustomKernel(T* params, const T* update,
                                             Index synthesized_updates_size) {
   Index update_block = synthesized_updates_size / indices_size;
   ScatterOpKernelBody<T, op> body;
-  CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) {
+  GPU_1D_KERNEL_LOOP(i, synthesized_updates_size) {
     int indices_i = i / update_block;
     int param_first_index = indices[indices_i];
     const T update_val = *update;
@@ -126,8 +126,8 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
     const Index first_dim_size = params.dimension(0);
     const Index indices_size = indices.size();
     const Index updates_size = updates.size();
-    GpuLaunchConfig config = GetCudaLaunchConfig(updates_size, d);
-    TF_CHECK_OK(CudaLaunchKernel(
+    GpuLaunchConfig config = GetGpuLaunchConfig(updates_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(
         scatter_op_gpu::ScatterOpCustomKernel<T, Index, op>, config.block_count,
         config.thread_per_block, 0, d.stream(), params.data(), updates.data(),
         indices.data(), first_dim_size, updates_size, indices_size));
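
The launch sequence above is the standard pattern: GetGpuLaunchConfig sizes a 1-D grid for the element count, and GpuLaunchKernel abstracts over <<<...>>> launches on CUDA and hipLaunchKernelGGL on ROCm. A self-contained sketch of a caller using the same pattern; ToyKernel and LaunchToy are hypothetical names for illustration, and the helper header path is my assumption for this era of the tree:

#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace {
// Hypothetical kernel: adds one to each element, using the same grid-stride
// loop macro as the scatter kernels in this diff.
template <typename T>
__global__ void ToyKernel(T* data, int n) {
  GPU_1D_KERNEL_LOOP(i, n) { data[i] = data[i] + T(1); }
}
}  // namespace

template <typename T>
void LaunchToy(const Eigen::GpuDevice& d, T* data, int n) {
  // Size a 1-D grid for n elements, then launch on the device's stream with
  // zero dynamic shared memory, matching the calls in the hunk above.
  GpuLaunchConfig config = GetGpuLaunchConfig(n, d);
  TF_CHECK_OK(GpuLaunchKernel(ToyKernel<T>, config.block_count,
                              config.thread_per_block, 0, d.stream(),
                              data, n));
}
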
@@ -147,8 +147,8 @@ struct ScatterScalarFunctor<GPUDevice, T, Index, op> {
     const Index first_dim_size = params.dimension(0);
     const Index indices_size = indices.size();
     const Index synthesized_updates_size = indices_size * params.dimension(1);
-    GpuLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d);
-    TF_CHECK_OK(CudaLaunchKernel(
+    GpuLaunchConfig config = GetGpuLaunchConfig(synthesized_updates_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(
         scatter_op_gpu::ScatterScalarOpCustomKernel<T, Index, op>,
         config.block_count, config.thread_per_block, 0, d.stream(),
         params.data(), update.data(), indices.data(), first_dim_size,
@@ -160,6 +160,6 @@ struct ScatterScalarFunctor<GPUDevice, T, Index, op> {
 }  // namespace functor
 }  // namespace tensorflow

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #endif  // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_

@@ -16,9 +16,9 @@ limitations under the License.
 // See docs in ../ops/state_ops.cc.
 #define EIGEN_USE_THREADS

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #include "tensorflow/core/kernels/scatter_nd_op.h"

@@ -434,7 +434,7 @@ TF_CALL_bool(REGISTER_SCATTER_ND_TENSOR_UPDATE_CPU);
 #undef REGISTER_SCATTER_ND_TENSOR_CPU

 // Registers GPU kernels.
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
   REGISTER_SCATTER_ND_ADD_SUB(type, GPU);
@@ -509,7 +509,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);
 #undef REGISTER_SCATTER_ND_TENSOR_SUB_GPU
 #undef REGISTER_SCATTER_ND_TENSOR_GPU

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 namespace functor {
 // Check whether updates.shape = indices.shape[:batch_dim] +
@@ -734,7 +734,7 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
 }
 }  // namespace functor

-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
 #define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
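
Note the subtle but necessary change in this hunk: the old line used #ifdef GOOGLE_CUDA, which only tests whether a single macro is defined. An || of two conditions requires #if, which evaluates the expression (undefined identifiers evaluate to 0), so the directive switches along with the added ROCm term.
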
@@ -777,6 +777,6 @@ TF_CALL_complex128(DECLARE_GPU_SPECS);

 }  // namespace functor

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 }  // namespace tensorflow

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #define EIGEN_USE_GPU

@@ -44,14 +44,14 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::ASSIGN> {
 template <typename T>
 struct LeftUpdate<T, scatter_nd_op::UpdateOp::ADD> {
   EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
-    CudaAtomicAdd(out, val);
+    GpuAtomicAdd(out, val);
   }
 };

 template <typename T>
 struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
   EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
-    CudaAtomicSub(out, val);
+    GpuAtomicSub(out, val);
   }
 };

@@ -63,8 +63,8 @@ struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
   EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
       std::complex<T>* out, const std::complex<T>& val) {
     T* ptr = reinterpret_cast<T*>(out);
-    CudaAtomicAdd(ptr, val.real());
-    CudaAtomicAdd(ptr, val.imag());
+    GpuAtomicAdd(ptr, val.real());
+    GpuAtomicAdd(ptr, val.imag());
   }
 };

@@ -86,7 +86,7 @@ __global__ void ScatterNdOpKernel(
     const Index slice_size) {
   auto update = LeftUpdate<T, op>();

-  CUDA_1D_KERNEL_LOOP(index, num_indices) {
+  GPU_1D_KERNEL_LOOP(index, num_indices) {
     Index i = 0;
     bool out_of_bounds = false;
 #pragma unroll
@@ -135,9 +135,9 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
       }
     }

-    GpuLaunchConfig config = GetCudaLaunchConfig(Toutput.size(), d);
+    GpuLaunchConfig config = GetGpuLaunchConfig(Toutput.size(), d);

-    TF_CHECK_OK(CudaLaunchKernel(ScatterNdOpKernel<T, Index, op, IXDIM>,
+    TF_CHECK_OK(GpuLaunchKernel(ScatterNdOpKernel<T, Index, op, IXDIM>,
                                  config.block_count, config.thread_per_block, 0,
                                  d.stream(), Tindices.data(), Tupdates.data(),
                                  Toutput.data(), output_shape_prefix,
@@ -181,4 +181,4 @@ TF_CALL_complex128(DECLARE_GPU_SPECS);

 }  // namespace tensorflow

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -279,7 +279,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
 TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);

 // Registers GPU kernels.
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
   REGISTER_SCATTER_ARITHMETIC(type, GPU);

@@ -291,7 +291,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 // Registers GPU kernels.
 #if TENSORFLOW_USE_SYCL
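
These REGISTER_SCATTER_*(type, GPU) macros bottom out in REGISTER_KERNEL_BUILDER calls that bind an op name and device to a kernel class. Schematically, under the assumption that the chain follows TensorFlow's usual registration idiom (the exact macro expansion is elided in this diff):

// Hedged sketch of what one expansion resembles: register ScatterAdd on GPU
// for one payload type and one index type.
REGISTER_KERNEL_BUILDER(Name("ScatterAdd")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<int32>("Tindices"),
                        ScatterUpdateOp<GPUDevice, float, int32,
                                        scatter_op::UpdateOp::ADD>);
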
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #define EIGEN_USE_GPU

@@ -53,4 +53,4 @@ DEFINE_GPU_SPECS(double);

 }  // namespace tensorflow

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM