diff --git a/tensorflow/core/util/gpu_device_functions.h b/tensorflow/core/util/gpu_device_functions.h
index bdfe528bcdf..5381c054d32 100644
--- a/tensorflow/core/util/gpu_device_functions.h
+++ b/tensorflow/core/util/gpu_device_functions.h
@@ -606,7 +606,7 @@ __device__ double GpuAtomicCasHelper(double* ptr, F accumulate) {
   // HIP has a bug in the implementation of __longlong_as_double
   // So workaround it by using reinterpret_cast<double*>.
   uint64_t result =
-      GpuAtomicCasHelper(reinterpret_cast<tensorflow::uint64*>(ptr),
+      GpuAtomicCasHelper(reinterpret_cast<unsigned long long*>(ptr),
                         [accumulate](tensorflow::uint64 a) {
                           return __double_as_longlong(
                               accumulate(*(reinterpret_cast<double*>(&a))));
@@ -614,7 +614,7 @@ __device__ double GpuAtomicCasHelper(double* ptr, F accumulate) {
   return *(reinterpret_cast<double*>(&result));
 #else
   return __longlong_as_double(GpuAtomicCasHelper(
-      reinterpret_cast<tensorflow::uint64*>(ptr),
+      reinterpret_cast<unsigned long long*>(ptr),
       [accumulate](tensorflow::uint64 a) {
         return __double_as_longlong(accumulate(__longlong_as_double(a)));
       }));
@@ -676,6 +676,38 @@ template <typename From, typename To>
 using ToTypeIfConvertible =
     typename std::enable_if<std::is_convertible<From, To>::value, To>::type;
 
+template <typename T>
+struct CudaSupportedTypeImpl {
+  using type = T;
+};
+
+template <>
+struct CudaSupportedTypeImpl<long long> {
+  using type = unsigned long long;
+};
+
+template <>
+struct CudaSupportedTypeImpl<unsigned long> {
+  using type =
+      typename std::conditional<sizeof(unsigned long) == sizeof(unsigned int),
+                                unsigned int, unsigned long long>::type;
+};
+
+template <>
+struct CudaSupportedTypeImpl<long> {
+  // This cast should be safe since modulo-2 addition should work fine. However,
+  // signed overflow is not handled correctly since it's undefined behavior.
+  using type = typename CudaSupportedTypeImpl<unsigned long>::type;
+};
+
+template <typename T>
+using CudaSupportedType = typename CudaSupportedTypeImpl<T>::type;
+
+template <typename T>
+__device__ CudaSupportedType<T>* ToCudaSupportedPtr(T* ptr) {
+  return reinterpret_cast<CudaSupportedType<T>*>(ptr);
+}
+
 }  // namespace detail
 
 // CUDA provides atomic ops, but not for all types.  We provide wrappers
@@ -683,13 +715,7 @@ using ToTypeIfConvertible =
 
 template <typename T, typename U>
 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicAdd(T* ptr, U value) {
-  return atomicAdd(ptr, value);
-}
-
-__device__ inline int64 GpuAtomicAdd(int64* ptr, int64 value) {
-  // This cast should be safe since module-2 addition should work fine. However,
-  // signed overflow is not handled correctly since it's undefined behavior.
-  return atomicAdd(reinterpret_cast<tensorflow::uint64*>(ptr), static_cast<tensorflow::uint64>(value));
+  return atomicAdd(detail::ToCudaSupportedPtr(ptr), value);
 }
 
 __device__ inline Eigen::half GpuAtomicAdd(Eigen::half* ptr,
@@ -765,7 +791,7 @@ CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicSub, CudaAtomicSub);
 // GpuAtomicMax
 template <typename T, typename U>
 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMax(T* ptr, U value) {
-  return atomicMax(ptr, value);
+  return atomicMax(detail::ToCudaSupportedPtr(ptr), value);
 }
 
 #if TENSORFLOW_USE_ROCM
@@ -817,11 +843,12 @@ __device__ inline Eigen::half GpuAtomicMax(Eigen::half* ptr,
 __device__ inline tensorflow::uint64 GpuAtomicMax(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
   return detail::GpuAtomicCasHelper(
-      ptr, [value](tensorflow::uint64 a) { return max(a, value); });
+      detail::ToCudaSupportedPtr(ptr),
+      [value](tensorflow::uint64 a) { return max(a, value); });
 }
 
 __device__ inline int64 GpuAtomicMax(int64* ptr, int64 value) {
-  return detail::GpuAtomicCasHelper(ptr,
+  return detail::GpuAtomicCasHelper(detail::ToCudaSupportedPtr(ptr),
                                     [value](int64 a) { return max(a, value); });
 }
 #endif
@@ -830,7 +857,7 @@ CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMax, CudaAtomicMax);
 // GpuAtomicMin
 template <typename T, typename U>
 __device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMin(T* ptr, U value) {
-  return atomicMin(ptr, value);
+  return atomicMin(detail::ToCudaSupportedPtr(ptr), value);
 }
 
 #if TENSORFLOW_USE_ROCM
@@ -882,11 +909,12 @@ __device__ inline Eigen::half GpuAtomicMin(Eigen::half* ptr,
 __device__ inline tensorflow::uint64 GpuAtomicMin(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
   return detail::GpuAtomicCasHelper(
-      ptr, [value](tensorflow::uint64 a) { return min(a, value); });
+      detail::ToCudaSupportedPtr(ptr),
+      [value](tensorflow::uint64 a) { return min(a, value); });
 }
 
 __device__ inline int64 GpuAtomicMin(int64* ptr, int64 value) {
-  return detail::GpuAtomicCasHelper(ptr,
+  return detail::GpuAtomicCasHelper(detail::ToCudaSupportedPtr(ptr),
                                     [value](int64 a) { return min(a, value); });
 }
 #endif
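
Note on the type mapping introduced above: CUDA's 64-bit atomic intrinsics (atomicAdd, atomicCAS, atomicMax, ...) are overloaded for unsigned long long but not for unsigned long, while tensorflow::int64/uint64 may alias long/unsigned long on LP64 Linux, so calling them directly on those pointers fails to compile or forces ad-hoc casts. CudaSupportedTypeImpl maps each integer type to a CUDA-supported type of the same width, and ToCudaSupportedPtr reinterprets the pointer accordingly before the wrappers call the intrinsic. What follows is a minimal, host-compilable sketch of that trait technique; it mirrors the names in the patch, but the main() checks are illustrative additions (and the __device__ qualifier is dropped), so it is not the TensorFlow header itself.

// Minimal host-compilable sketch of the CudaSupportedType mapping shown in the
// patch. It only exercises the type arithmetic; in the real header the mapped
// pointer is handed to CUDA's atomic intrinsics inside __device__ code.
#include <cstdint>
#include <type_traits>

namespace detail {

template <typename T>
struct CudaSupportedTypeImpl {
  using type = T;  // Most types pass through unchanged.
};

template <>
struct CudaSupportedTypeImpl<long long> {
  using type = unsigned long long;  // The 64-bit atomics take the unsigned flavor.
};

template <>
struct CudaSupportedTypeImpl<unsigned long> {
  // 'unsigned long' is 32-bit on LLP64 ABIs (e.g. Windows) and 64-bit on LP64
  // Linux; pick the CUDA-supported type of matching width.
  using type =
      typename std::conditional<sizeof(unsigned long) == sizeof(unsigned int),
                                unsigned int, unsigned long long>::type;
};

template <>
struct CudaSupportedTypeImpl<long> {
  using type = typename CudaSupportedTypeImpl<unsigned long>::type;
};

template <typename T>
using CudaSupportedType = typename CudaSupportedTypeImpl<T>::type;

// The real helper is __device__-qualified; omitted here so the sketch builds
// with a plain host compiler.
template <typename T>
CudaSupportedType<T>* ToCudaSupportedPtr(T* ptr) {
  return reinterpret_cast<CudaSupportedType<T>*>(ptr);
}

}  // namespace detail

int main() {
  // Whatever int64_t/uint64_t alias to on this platform, the mapped types keep
  // the same width, so the reinterpret_cast in ToCudaSupportedPtr is
  // layout-compatible with the original object.
  static_assert(sizeof(detail::CudaSupportedType<std::int64_t>) == 8, "width kept");
  static_assert(sizeof(detail::CudaSupportedType<std::uint64_t>) == 8, "width kept");
  // Floating-point types are untouched and keep using the native overloads.
  static_assert(std::is_same<detail::CudaSupportedType<float>, float>::value,
                "float passes through");

  std::int64_t counter = 0;
  // In device code the wrappers now do, e.g.:
  //   atomicAdd(detail::ToCudaSupportedPtr(&counter), 1);
  // which compiles whether int64 is 'long' (LP64 Linux) or 'long long'.
  auto* mapped = detail::ToCudaSupportedPtr(&counter);
  static_cast<void>(mapped);
  return 0;
}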