diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index e2418dda59a..7e2a8b363c5 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -747,9 +747,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
 #endif
@@ -759,9 +757,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
 #endif
@@ -924,9 +920,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
 #endif
@@ -936,9 +930,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
 #endif
@@ -1411,9 +1403,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
 #endif
@@ -1423,9 +1413,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
 #endif
@@ -1525,9 +1513,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
 #endif
@@ -1537,9 +1523,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
 #endif
@@ -4275,9 +4259,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
 #endif
@@ -4287,9 +4269,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
 #endif
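The training_ops.cc hunks above all make the same change: the complex64/complex128 GPU declarations and registrations were previously compiled out for both NVCC and ROCm builds, and are now compiled out only for NVCC. A minimal sketch of the before/after guard semantics (illustrative only, not part of the patch; the TENSORFLOW_USE_* macros are assumed to be set by the respective GPU toolchains):

```cpp
#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
// Old guard: complex kernels existed only in builds that are neither
// NVCC nor ROCm (e.g. CPU-only or clang-CUDA builds).
#endif

#ifndef TENSORFLOW_USE_NVCC
// New guard: ROCm builds also register the complex kernels; the
// hand-written device sqrt/rsqrt added in training_ops_gpu.cu.cc below
// supplies the math that Eigen still lacks under NVCC (b/143684500).
#endif
```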
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 1e53cfed777..92496e63e1a 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -118,12 +118,182 @@ struct ApplyGradientDescent {
   }
 };
 
+#if TENSORFLOW_USE_ROCM
+
+#include "rocm/include/hip/hip_complex.h"
+
+// If any kernels involving complex sqrt/rsqrt are compiled with ROCm, the
+// build completes without errors, but the resulting executable ends up
+// unusable (throwing "no device code available for function" errors for
+// completely unrelated kernels).
+// We also can't cast to hipFloatComplex etc., because (as of 2020-01) HIP
+// does not provide sqrt for complex.
+// We have no choice but to implement sqrt and rsqrt by hand.
+template <typename T>
+__device__ T impl_sqrt(T x) {
+  return sqrt(x);
+}
+template <typename T>
+__device__ T impl_rsqrt(T x) {
+  return rsqrt(x);
+}
+template <>
+__device__ Eigen::half impl_sqrt(Eigen::half x) {
+  return __float2half(sqrt(__half2float(x)));
+}
+template <>
+__device__ Eigen::half impl_rsqrt(Eigen::half x) {
+  return __float2half(rsqrt(__half2float(x)));
+}
+
+template <typename T>
+__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
+  T re = x.real(), im = x.imag();
+  T mod_x = sqrt(re * re + im * im);
+  const T root2 = 0.7071067811865475;
+  // We pick the root with the same sign of the imaginary component as
+  // the input.
+  T root[2] = {T(sqrt(mod_x + re) * root2),
+               T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
+  // hcc/clang is really weird with its support of complex in device code;
+  // for some reason it does not permit a 2-argument constructor
+  return *(reinterpret_cast<std::complex<T>*>(&root));
+}
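The two square roots above are the half-angle form of the principal branch: with mod_x = |z|, sqrt(z) = sqrt((|z| + Re z)/2) + i*sign(Im z)*sqrt((|z| - Re z)/2). A host-side sanity check of that identity against std::sqrt (an illustrative sketch, not part of the patch):

```cpp
#include <cassert>
#include <cmath>
#include <complex>

// Mirrors the device formula in impl_sqrt(std::complex<T>) above.
std::complex<double> manual_sqrt(std::complex<double> z) {
  double mod_z = std::abs(z);
  const double root2 = 0.7071067811865475;  // 1/sqrt(2)
  double re_root = std::sqrt(mod_z + z.real()) * root2;
  double im_root =
      std::sqrt(mod_z - z.real()) * root2 * (z.imag() >= 0 ? 1.0 : -1.0);
  return {re_root, im_root};
}

int main() {
  // (1 + 2i)^2 == -3 + 4i, and Im(z) >= 0 picks the root with Im >= 0.
  std::complex<double> z(-3.0, 4.0);
  assert(std::abs(manual_sqrt(z) - std::sqrt(z)) < 1e-12);
  assert(std::abs(manual_sqrt(z) - std::complex<double>(1.0, 2.0)) < 1e-12);
}
```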
+
+template <typename T>
+__device__ T rsqrt_helper(T x) {
+  return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
+}
+
+template <typename T>
+__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
+  T re = x.real(), im = x.imag();
+  T r = rsqrt(re * re + im * im);
+  T ar2 = re * r * r;
+  const T root2 = 0.7071067811865475;
+  T root[2];
+  // With float, calculating 1+re*r and 1-re*r may result in excessive errors
+  // due to subtraction of two close values. We have to get fancy.
+  root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
+                          ? rsqrt_helper(im * im * r * r)
+                          : 1 + re * r)) *
+            root2;
+  root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
+                          ? rsqrt_helper(im * im * r * r)
+                          : 1 - re * r)) *
+            root2 * (im >= 0 ? -1. : 1.);
+  return *(reinterpret_cast<std::complex<T>*>(&root));
+}
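rsqrt_helper is the numerically interesting detail. With u = (im*r)^2 = 1 - (re*r)^2, the polynomial 0.5u + 0.125u^2 + 0.0625u^3 is the Taylor expansion of 1 - sqrt(1 - u), which equals 1 + re*r when re*r is near -1 (and 1 - re*r when re*r is near +1) without ever subtracting two nearly equal numbers. A float demo of the cancellation it avoids (illustrative sketch, not part of the patch; the device rsqrt is replaced by 1/sqrt on the host):

```cpp
#include <cmath>
#include <cstdio>

// Same polynomial as the device-side rsqrt_helper above: the Taylor series
// of 1 - sqrt(1 - u) through the u^3 term.
float rsqrt_helper(float x) {
  return 0.5f * x + 0.125f * x * x + 0.0625f * x * x * x;
}

int main() {
  float re = -1.0f, im = 1e-3f;  // z just off the negative real axis
  float r = 1.0f / std::sqrt(re * re + im * im);  // r = 1/|z|
  float naive = 1.0f + re * r;                    // catastrophic cancellation
  float series = rsqrt_helper(im * im * r * r);   // u = (im*r)^2
  double exact = 1.0 + static_cast<double>(re) /
                           std::sqrt(static_cast<double>(re) * re +
                                     static_cast<double>(im) * im);
  // naive keeps only a few significant bits; series agrees with exact.
  std::printf("naive=%.9g  series=%.9g  exact=%.9g\n", naive, series, exact);
}
```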
+
+template <typename T>
+__global__ void ApplyAdagradKernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                   const T* lr, const T* grad,
+                                   bool update_slots) {
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    if (update_slots) accum[i] += grad[i] * grad[i];
+    var[i] -= lr[0] * grad[i] * impl_rsqrt(accum[i]);
+  }
+}
+
+template <typename T>
+__global__ void ApplyAdagradV2Kernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                     const T* lr, const T* epsilon,
+                                     const T* grad, bool update_slots) {
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    if (update_slots) accum[i] += grad[i] * grad[i];
+    T update = grad[i] / (impl_sqrt(accum[i]) + epsilon[0]);
+    var[i] -= lr[0] * update;
+  }
+}
+
+template <typename T>
+__global__ void ApplyAdadeltaKernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                    T* accum_update, const T* plr,
+                                    const T* prho, const T* peps,
+                                    const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
+    T update =
+        impl_sqrt(accum_update[i] + eps) * grad[i] * impl_rsqrt(accum[i] + eps);
+    var[i] -= update * lr;
+    accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
+  }
+}
+
+template <typename T>
+__global__ void ApplyRMSPropKernel(GpuLaunchConfig cfg, T* var, T* ms, T* mom,
+                                   const T* plr, const T* prho,
+                                   const T* pmomentum, const T* peps,
+                                   const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  T momentum = pmomentum[0];
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
+    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(eps + ms[i]);
+    var[i] -= mom[i];
+  }
+}
+
+template <typename T>
+__global__ void ApplyCenteredRMSPropKernel(GpuLaunchConfig cfg, T* var, T* mg,
+                                           T* ms, T* mom, const T* plr,
+                                           const T* prho, const T* pmomentum,
+                                           const T* peps, const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  T momentum = pmomentum[0];
+  T one_minus_rho = T(1.0) - rho;
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
+    mg[i] += one_minus_rho * (grad[i] - mg[i]);
+    T denom = (ms[i] - mg[i] * mg[i]) + eps;
+    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(denom);
+    var[i] -= mom[i];
+  }
+}
+
+namespace kernel_forward {
+bool to_pointers(bool x) { return x; }
+template <typename T>
+typename T::PointerType to_pointers(T& x) {
+  return x.data();
+}
+template <typename T>
+typename T::ConstPointerType to_pointers(const T& x) {
+  return x.data();
+}
+
+template <typename T, typename... CallerArgs, typename... KernelArgs>
+void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
+                      CallerArgs... args) {
+  int32 data_dim = var.dimension(0);
+  auto config = GetGpuLaunchConfig(data_dim, d);
+  TF_CHECK_OK(GpuLaunchKernel(func, config.block_count,
+                              config.thread_per_block, 0, d.stream(), config,
+                              var.data(), to_pointers(args)...));
+}
+};  // namespace kernel_forward
+
+using kernel_forward::wrap_kernel_call;
+
+#endif
+
 template <typename T>
 struct ApplyAdagrad {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad, bool update_slots) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
+                     update_slots);
+#else
     if (update_slots) {
       accum.device(d) += grad.square();
     }
@@ -131,6 +301,7 @@ struct ApplyAdagrad {
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
     var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
+#endif
   }
 };
 
@@ -141,6 +312,10 @@ struct ApplyAdagradV2 {
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad, bool update_slots) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
+                     update_slots);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -150,9 +325,9 @@ struct ApplyAdagradV2 {
     const auto update =
         grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
     var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
+#endif
   }
 };
-
 template <typename T>
 struct ApplyAdadelta {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -162,6 +337,10 @@ struct ApplyAdadelta {
                   typename TTypes<T>::ConstScalar rho,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
+                     rho, epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -177,6 +356,7 @@ struct ApplyAdadelta {
         accum_update * rho.reshape(single).broadcast(bcast) +
         update.square() *
             (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
+#endif
   }
 };
 
@@ -489,6 +669,10 @@ struct ApplyRMSProp {
                   typename TTypes<T>::ConstScalar momentum,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
+                     epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -501,6 +685,7 @@ struct ApplyRMSProp {
         lr.reshape(single).broadcast(bcast) * grad /
             ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
     var.device(d) -= mom;
+#endif
   }
 };
 
@@ -514,6 +699,10 @@ struct ApplyCenteredRMSProp {
                   typename TTypes<T>::ConstScalar momentum,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
+                     rho, momentum, epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -526,6 +715,7 @@ struct ApplyCenteredRMSProp {
     mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
                     lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
     var.device(d) -= mom;
+#endif
   }
 };
 
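The #if TENSORFLOW_USE_ROCM arms above all funnel through kernel_forward::wrap_kernel_call, which deduces the kernel's raw-pointer parameter list from the function pointer and maps each Eigen tensor argument to its device pointer via to_pointers (bools pass through unchanged). A minimal host-side analogue of that forwarding pattern, with hypothetical names (FakeTensor, fake_kernel, wrap_call) standing in for the real types (illustrative only, not part of the patch):

```cpp
#include <cstdio>

// Stand-in for an Eigen TensorMap: exposes data() and a pointer typedef.
struct FakeTensor {
  using ConstPointerType = const float*;
  const float* data() const { return buf; }
  float buf[2] = {1.0f, 2.0f};
};

bool to_pointers(bool x) { return x; }
template <typename T>
typename T::ConstPointerType to_pointers(const T& x) {
  return x.data();
}

// KernelArgs is deduced from the function pointer, CallerArgs from the
// trailing arguments; each argument is mapped through to_pointers.
template <typename... KernelArgs, typename... CallerArgs>
void wrap_call(void (*func)(KernelArgs...), CallerArgs... args) {
  func(to_pointers(args)...);
}

void fake_kernel(const float* grad, bool update_slots) {
  std::printf("grad[0]=%f update_slots=%d\n", grad[0], update_slots);
}

int main() {
  FakeTensor grad;
  wrap_call(fake_kernel, grad, true);  // FakeTensor -> const float*, bool -> bool
}
```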
@@ -599,9 +789,8 @@ struct ApplyPowerSign {
 template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
 template struct functor::ApplyGradientDescent<GPUDevice, float>;
 template struct functor::ApplyGradientDescent<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
+                             // complex sqrt
 template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
 template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
 #endif
@@ -609,9 +798,8 @@ template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
 template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdagrad<GPUDevice, float>;
 template struct functor::ApplyAdagrad<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
+                             // complex sqrt
 template struct functor::ApplyAdagrad<GPUDevice, complex64>;
 template struct functor::ApplyAdagrad<GPUDevice, complex128>;
 #endif
@@ -619,9 +807,8 @@ template struct functor::ApplyAdagrad<GPUDevice, complex128>;
 template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdagradV2<GPUDevice, float>;
 template struct functor::ApplyAdagradV2<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
+                             // complex sqrt
 template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
 template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
 #endif
@@ -629,9 +816,8 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
 template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdadelta<GPUDevice, float>;
 template struct functor::ApplyAdadelta<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
+                             // complex sqrt
 template struct functor::ApplyAdadelta<GPUDevice, complex64>;
 template struct functor::ApplyAdadelta<GPUDevice, complex128>;
 #endif
@@ -710,9 +896,8 @@ template struct functor::ApplyAdaMax<GPUDevice, double>;
 template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
 template struct functor::ApplyRMSProp<GPUDevice, float>;
 template struct functor::ApplyRMSProp<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
+                             // complex sqrt
 template struct functor::ApplyRMSProp<GPUDevice, complex64>;
 template struct functor::ApplyRMSProp<GPUDevice, complex128>;
 #endif
@@ -720,9 +905,8 @@ template struct functor::ApplyRMSProp<GPUDevice, complex128>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
+                             // complex sqrt
 template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;
 #endif
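Before the Python test changes: each of the GPU kernels added above is a per-element scalar recurrence, which is what makes the flat GPU_1D_KERNEL_LOOP form possible. As a concrete check of one update rule, here is the ApplyCenteredRMSPropKernel recurrence run on a single host-side parameter (illustrative sketch, not part of the patch; hyperparameter values chosen arbitrarily, device impl_rsqrt replaced by 1/sqrt):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float var = 1.0f, mg = 0.0f, ms = 0.0f, mom = 0.0f;
  const float lr = 0.01f, rho = 0.9f, momentum = 0.9f, eps = 1e-7f;
  const float one_minus_rho = 1.0f - rho;
  for (int step = 0; step < 3; ++step) {
    float grad = 2.0f * var;                   // gradient of f(x) = x^2
    ms += one_minus_rho * (grad * grad - ms);  // second-moment estimate
    mg += one_minus_rho * (grad - mg);         // first-moment estimate
    float denom = (ms - mg * mg) + eps;        // centered variance + eps
    mom = mom * momentum + lr * grad / std::sqrt(denom);
    var -= mom;
    std::printf("step %d: var = %f\n", step, var);
  }
}
```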
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
index c5a7e0414ce..0827d367110 100644
--- a/tensorflow/python/keras/optimizer_v2/adadelta_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.platform import test
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
+if not test_util.IsBuiltWithNvcc():
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
index 8c69a19171a..a4b331e622c 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -39,7 +39,7 @@ from tensorflow.python.platform import test
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
+if not test_util.IsBuiltWithNvcc():
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
index 66da72ba0ac..612d8ba0159 100644
--- a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -41,7 +41,7 @@ from tensorflow.python.platform import test
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
+if not test_util.IsBuiltWithNvcc():
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 _TEST_PARAM_VALUES = [