From 3693256ba03f67d38fe33daa495b296d45539076 Mon Sep 17 00:00:00 2001
From: Eugene Kuznetsov
Date: Sun, 22 Dec 2019 17:17:32 -0800
Subject: [PATCH 1/3] ROCm-optimized training apply ops

---
 tensorflow/core/kernels/training_ops.cc       |  50 ++--
 .../core/kernels/training_ops_gpu.cu.cc       | 234 +++++++++++++++---
 .../keras/optimizer_v2/adadelta_test.py       |   3 +-
 .../python/keras/optimizer_v2/adagrad_test.py |   3 +-
 .../python/keras/optimizer_v2/rmsprop_test.py |   3 +-
 5 files changed, 226 insertions(+), 67 deletions(-)

diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 467087b7864..147c39bd4d7 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -670,9 +670,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -684,9 +683,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -853,9 +851,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -867,9 +864,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -1348,9 +1344,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -1362,9 +1357,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -1468,9 +1462,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -1482,9 +1475,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -4183,9 +4175,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -4197,9 +4188,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 8b7f5dc2e40..f9bce685828 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -118,12 +118,184 @@ struct ApplyGradientDescent {
   }
 };
 
+#if TENSORFLOW_USE_ROCM
+
+#include "rocm/include/hip/hip_complex.h"
+
+// If any kernel involving complex sqrt/rsqrt is compiled with ROCm, the
+// build completes without errors, but the resulting executable ends up
+// unusable (throwing "no device code available for function" errors for
+// completely unrelated kernels).
+// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
+// not provide sqrt for complex.
+// We have no choice but to implement sqrt and rsqrt by hand.
+template <typename T>
+__device__ T impl_sqrt(T x) {
+  return sqrt(x);
+}
+template <typename T>
+__device__ T impl_rsqrt(T x) {
+  return rsqrt(x);
+}
+template <>
+__device__ Eigen::half impl_sqrt(Eigen::half x) {
+  return __float2half(sqrt(__half2float(x)));
+}
+template <>
+__device__ Eigen::half impl_rsqrt(Eigen::half x) {
+  return __float2half(rsqrt(__half2float(x)));
+}
+
+template <typename T>
+__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
+  T re = x.real(), im = x.imag();
+  T mod_x = sqrt(re * re + im * im);
+  const T root2 = 0.7071067811865475;
+  // we pick the root with the same sign of the imaginary component as the input
+  T root[2] = {T(sqrt(mod_x + re) * root2),
+               T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
+  // hcc/clang is really weird with its support of complex in device code;
+  // for some reason it does not permit a 2-argument constructor
+  return *(reinterpret_cast<std::complex<T>*>(&root));
+}
+
+template <typename T>
+__device__ T rsqrt_helper(T x) {
+  return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
+}
+
+template <typename T>
+__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
+  T re = x.real(), im = x.imag();
+  T r = rsqrt(re * re + im * im);
+  T ar2 = re * r * r;
+  const T root2 = 0.7071067811865475;
+  T root[2];
+  // With float, calculating 1+re*r and 1-re*r may result in excessive errors
+  // due to subtraction of two close values. We have to get fancy.
+  root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
+                          ? rsqrt_helper(im * im * r * r)
+                          : 1 + re * r)) *
+            root2;
+  root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
+                          ? rsqrt_helper(im * im * r * r)
+                          : 1 - re * r)) *
+            root2 * (im >= 0 ? -1. : 1.);
+  return *(reinterpret_cast<std::complex<T>*>(&root));
+}
+
+template <typename T>
+__global__ void ApplyAdagradKernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                   const T* lr, const T* grad,
+                                   bool update_slots) {
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    if (update_slots) accum[i] += grad[i] * grad[i];
+    var[i] -= lr[0] * grad[i] * impl_rsqrt(accum[i]);
+  }
+}
+
+template <typename T>
+__global__ void ApplyAdagradV2Kernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                     const T* lr, const T* epsilon,
+                                     const T* grad, bool update_slots) {
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    if (update_slots) accum[i] += grad[i] * grad[i];
+    T update = grad[i] / (impl_sqrt(accum[i]) + epsilon[0]);
+    var[i] -= lr[0] * update;
+  }
+}
+
+template <typename T>
+__global__ void ApplyAdadeltaKernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                    T* accum_update, const T* plr,
+                                    const T* prho, const T* peps,
+                                    const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
+    T update =
+        impl_sqrt(accum_update[i] + eps) * grad[i] * impl_rsqrt(accum[i] + eps);
+    var[i] -= update * lr;
+    accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
+  }
+}
+
+template <typename T>
+__global__ void ApplyRMSPropKernel(GpuLaunchConfig cfg, T* var, T* ms, T* mom,
+                                   const T* plr, const T* prho,
+                                   const T* pmomentum, const T* peps,
+                                   const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  T momentum = pmomentum[0];
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
+    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(eps + ms[i]);
+    var[i] -= mom[i];
+  }
+}
+
+template <typename T>
+__global__ void ApplyCenteredRMSPropKernel(GpuLaunchConfig cfg, T* var, T* mg,
+                                           T* ms, T* mom, const T* plr,
+                                           const T* prho, const T* pmomentum,
+                                           const T* peps, const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  T momentum = pmomentum[0];
+  T one_minus_rho = T(1.0) - rho;
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
+    mg[i] += one_minus_rho * (grad[i] - mg[i]);
+    T denom = (ms[i] - mg[i] * mg[i]) + eps;
+    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(denom);
+    var[i] -= mom[i];
+  }
+}
+#endif
+
+#if TENSORFLOW_USE_ROCM
+
+namespace kernel_forward {
+bool to_pointers(bool x) { return x; }
+template <typename T>
+typename T::PointerType to_pointers(T& x) {
+  return x.data();
+}
+template <typename T>
+typename T::ConstPointerType to_pointers(const T& x) {
+  return x.data();
+}
+
+template <typename T, typename... CallerArgs, typename... KernelArgs>
+void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
+                      CallerArgs... args) {
+  int32 data_dim = var.dimension(0);
+  auto config = GetGpuLaunchConfig(data_dim, d);
+  TF_CHECK_OK(GpuLaunchKernel(func, config.block_count, config.thread_per_block,
+                              0, d.stream(), config, var.data(),
+                              to_pointers(args)...));
+}
+};  // namespace kernel_forward
+
+using kernel_forward::wrap_kernel_call;
+
+#endif
+
 template <typename T>
 struct ApplyAdagrad<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad, bool update_slots) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
+                     update_slots);
+#else
     if (update_slots) {
       accum.device(d) += grad.square();
     }
@@ -131,6 +303,7 @@ struct ApplyAdagrad<GPUDevice, T> {
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
     var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
+#endif
   }
 };
 
@@ -141,6 +314,10 @@ struct ApplyAdagradV2<GPUDevice, T> {
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad, bool update_slots) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
+                     update_slots);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -150,9 +327,9 @@ struct ApplyAdagradV2<GPUDevice, T> {
     const auto update =
         grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
     var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
+#endif
   }
 };
-
 template <typename T>
 struct ApplyAdadelta<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -162,6 +339,10 @@ struct ApplyAdadelta<GPUDevice, T> {
                   typename TTypes<T>::ConstScalar rho,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
+                     rho, epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -177,6 +358,7 @@ struct ApplyAdadelta<GPUDevice, T> {
         accum_update * rho.reshape(single).broadcast(bcast) +
         update.square() *
            (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
+#endif
   }
 };
 
@@ -414,6 +596,10 @@ struct ApplyRMSProp<GPUDevice, T> {
                   typename TTypes<T>::ConstScalar momentum,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
+                     epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -426,6 +612,7 @@
         lr.reshape(single).broadcast(bcast) * grad /
             ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
     var.device(d) -= mom;
+#endif
   }
 };
 
@@ -439,6 +626,10 @@ struct ApplyCenteredRMSProp<GPUDevice, T> {
                   typename TTypes<T>::ConstScalar momentum,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
+                     rho, momentum, epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -451,6 +642,7 @@
     mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
                     lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
     var.device(d) -= mom;
+#endif
   }
 };
 
@@ -524,9 +716,7 @@ struct ApplyPowerSign<GPUDevice, T> {
 template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
 template struct functor::ApplyGradientDescent<GPUDevice, float>;
 template struct functor::ApplyGradientDescent<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
 template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
@@ -536,9 +726,7 @@ template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
 template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdagrad<GPUDevice, float>;
 template struct functor::ApplyAdagrad<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyAdagrad<GPUDevice, complex64>;
 template struct functor::ApplyAdagrad<GPUDevice, complex128>;
@@ -548,9 +736,7 @@ template struct functor::ApplyAdagrad<GPUDevice, complex128>;
 template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdagradV2<GPUDevice, float>;
 template struct functor::ApplyAdagradV2<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
 template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
@@ -560,9 +746,7 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
 template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdadelta<GPUDevice, float>;
 template struct functor::ApplyAdadelta<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyAdadelta<GPUDevice, complex64>;
 template struct functor::ApplyAdadelta<GPUDevice, complex128>;
@@ -580,9 +764,7 @@ template struct functor::ApplyFtrlV2<GPUDevice, double>;
 template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
 template struct functor::ApplyMomentum<GPUDevice, float>;
 template struct functor::ApplyMomentum<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyMomentum<GPUDevice, complex64>;
 template struct functor::ApplyMomentum<GPUDevice, complex128>;
@@ -592,9 +774,7 @@ template struct functor::ApplyMomentum<GPUDevice, complex128>;
 template struct functor::ApplyKerasMomentum<GPUDevice, Eigen::half>;
 template struct functor::ApplyKerasMomentum<GPUDevice, float>;
 template struct functor::ApplyKerasMomentum<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyKerasMomentum<GPUDevice, complex64>;
 template struct functor::ApplyKerasMomentum<GPUDevice, complex128>;
@@ -609,9 +789,7 @@ template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int32>;
 template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int64>;
 template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int32>;
 template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int64>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
 #ifndef PLATFORM_WINDOWS
 template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int32>;
 template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int64>;
@@ -623,9 +801,7 @@ template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int64>;
 template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdam<GPUDevice, float>;
 template struct functor::ApplyAdam<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyAdam<GPUDevice, complex64>;
 template struct functor::ApplyAdam<GPUDevice, complex128>;
@@ -643,9 +819,7 @@ template struct functor::ApplyAdaMax<GPUDevice, double>;
 template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
 template struct functor::ApplyRMSProp<GPUDevice, float>;
 template struct functor::ApplyRMSProp<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyRMSProp<GPUDevice, complex64>;
 template struct functor::ApplyRMSProp<GPUDevice, complex128>;
@@ -655,9 +829,7 @@ template struct functor::ApplyRMSProp<GPUDevice, complex128>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
index 76f9f1cfb90..d516e32672a 100644
--- a/tensorflow/python/keras/optimizer_v2/adadelta_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
@@ -35,8 +35,7 @@ from tensorflow.python.platform import test
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and
-    not test.is_built_with_rocm()):
+if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows"):
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
index 9cbcd27b5d8..685a4bea388 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -38,8 +38,7 @@ from tensorflow.python.platform import test
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and
-    not test.is_built_with_rocm()):
+if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows"):
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
index a3480d62f21..2e3a907a6f2 100644
--- a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -41,8 +41,7 @@ from tensorflow.python.platform import test
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and
-    not test.is_built_with_rocm()):
+if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows"):
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 _TEST_PARAM_VALUES = [

From fb91bf645837c5755daedfe537a1acf8f5aa8f65 Mon Sep 17 00:00:00 2001
From: Eugene Kuznetsov
Date: Sat, 1 Feb 2020 01:48:46 -0800
Subject: [PATCH 2/3] Formatting

---
 tensorflow/core/kernels/training_ops.cc | 30 +++++++++----------------
 1 file changed, 10 insertions(+), 20 deletions(-)

diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 147c39bd4d7..efef69d2e54 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -670,8 +670,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -683,8 +682,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -851,8 +849,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -864,8 +861,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -1344,8 +1340,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -1357,8 +1352,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -1462,8 +1456,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -1475,8 +1468,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -4175,8 +4167,7 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -4188,8 +4179,7 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined( \
-    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
+#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);

From 83d44dd58109759235bd7ea3ac601de03be92494 Mon Sep 17 00:00:00 2001
From: Eugene Kuznetsov
Date: Mon, 6 Apr 2020 20:28:27 -0700
Subject: [PATCH 3/3] Reviewer requested changes

---
 tensorflow/core/kernels/training_ops_gpu.cu.cc | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 875256f3f8d..d33dc5aa8cc 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -151,7 +151,8 @@ __device__ std::complex<T> impl_sqrt(std::complex<T> x) {
   T re = x.real(), im = x.imag();
   T mod_x = sqrt(re * re + im * im);
   const T root2 = 0.7071067811865475;
-  // we pick the root with the same sign of the imaginary component as the input
+  // We pick the root with the same sign of the imaginary component as
+  // the input.
   T root[2] = {T(sqrt(mod_x + re) * root2),
                T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
   // hcc/clang is really weird with its support of complex in device code;
@@ -256,9 +257,6 @@ __global__ void ApplyCenteredRMSPropKernel(GpuLaunchConfig cfg, T* var, T* mg,
     var[i] -= mom[i];
   }
 }
-#endif
-
-#if TENSORFLOW_USE_ROCM
 
 namespace kernel_forward {
 bool to_pointers(bool x) { return x; }
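
For reference, the hand-rolled complex formulas in patch 1 are easy to sanity-check off-device. Below is a minimal host-side C++ sketch, not part of the patch series, that mirrors the impl_sqrt/impl_rsqrt math with std::complex and exercises the float cancellation guard near the negative real axis; the file name, build line, and test points are illustrative assumptions.

// Host-side sanity check (not part of the patches): mirrors the impl_sqrt /
// impl_rsqrt formulas so branch selection and the rsqrt_helper series can be
// verified against std::complex on the CPU.
// Hypothetical build: g++ -std=c++11 -o check check.cc && ./check
#include <cmath>
#include <complex>
#include <cstdio>
#include <type_traits>

template <typename T>
T rsqrt_helper(T x) {
  // Series for 1 - sqrt(1 - x); used where 1 + re*r would cancel in float.
  return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}

template <typename T>
std::complex<T> ref_sqrt(std::complex<T> x) {
  T re = x.real(), im = x.imag();
  T mod_x = std::sqrt(re * re + im * im);
  const T root2 = 0.7071067811865475;
  // Same root selection as the kernel: the imaginary part of the result
  // keeps the sign of the input's imaginary part (principal branch).
  return std::complex<T>(std::sqrt(mod_x + re) * root2,
                         std::sqrt(mod_x - re) * root2 * (im >= 0 ? 1 : -1));
}

template <typename T>
std::complex<T> ref_rsqrt(std::complex<T> x) {
  T re = x.real(), im = x.imag();
  T r = T(1) / std::sqrt(re * re + im * im);  // r = 1/|x|
  const T root2 = 0.7071067811865475;
  // Near re*r = -1 (or +1), 1 + re*r (or 1 - re*r) loses all float precision,
  // so the series path substitutes 1 - sqrt(1 - (im*r)^2).
  T root0 = std::sqrt(r * ((std::is_same<T, float>::value && re * r < T(-0.98))
                               ? rsqrt_helper(im * im * r * r)
                               : T(1) + re * r)) *
            root2;
  T root1 = std::sqrt(r * ((std::is_same<T, float>::value && re * r > T(0.98))
                               ? rsqrt_helper(im * im * r * r)
                               : T(1) - re * r)) *
            root2 * (im >= 0 ? -1 : 1);
  return std::complex<T>(root0, root1);
}

int main() {
  // Easy case: sqrt(3 + 4i) is exactly 2 + i.
  std::complex<float> a(3.0f, 4.0f);
  std::complex<float> s = ref_sqrt(a);
  std::printf("sqrt(3+4i)     = (%g, %g), expect (2, 1)\n", s.real(), s.imag());

  // Hard case: near the negative real axis the naive 1 + re*r evaluates to 0
  // in float, and the series path must recover the ~5e-5 real component.
  std::complex<float> b(-1.0f, 1e-4f);
  std::complex<float> got = ref_rsqrt(b);
  std::complex<float> want = 1.0f / std::sqrt(b);
  std::printf("rsqrt(-1+1e-4i): got (%g, %g), want (%g, %g)\n", got.real(),
              got.imag(), want.real(), want.imag());
  return 0;
}

Both outputs should agree to float precision; dropping the rsqrt_helper branch makes the second real component collapse to 0, which is the cancellation the patch 1 comment describes.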