Move/enable helper functions in training_ops_gpu

- Moves the ROCm helper functions to the top of the file and enables
  them for use with CUDA as well.
- No functional change.
This commit is contained in:
Ben Barsdell 2020-12-18 10:43:43 +11:00
parent 483d6fe292
commit 97024506de

View File

@ -27,6 +27,89 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_complex.h"
#endif // TENSORFLOW_USE_ROCM
// if any kernels involving complex sqrt/rsqrt are compiled with ROCm, build
// process completes without errors,but the resulting executable ends up
// unusable (throwing errors "no device code available for function" for
/// completely unrelated kernels.)
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
// not provide sqrt for complex.
// We have no choice but to implement sqrt and rsqrt by hand
template <typename T>
__device__ T impl_sqrt(T x) {
return sqrt(x);
}
template <typename T>
__device__ T impl_rsqrt(T x) {
return rsqrt(x);
}
template <>
__device__ Eigen::half impl_sqrt(Eigen::half x) {
return __float2half(sqrt(__half2float(x)));
}
template <>
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
return __float2half(rsqrt(__half2float(x)));
}
template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T mod_x = sqrt(re * re + im * im);
const T root2 = 0.7071067811865475;
// We pick the root with the same sign of the imaginary component as
// the input.
T root[2] = {T(sqrt(mod_x + re) * root2),
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
// hcc/clang is really weird with its support of complex in device code;
// for some reason it does not permit a 2-argument constructor
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <class T>
__device__ T rsqrt_helper(T x) {
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}
template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T r = rsqrt(re * re + im * im);
T ar2 = re * r * r;
const T root2 = 0.7071067811865475;
T root[2];
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
// due to subtraction of two close values. We have to get fancy
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
? rsqrt_helper(im * im * r * r)
: max(T(0.0), 1 + re * r))) *
root2;
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
? rsqrt_helper(im * im * r * r)
: max(T(0.0), 1 - re * r))) *
root2 * (im >= 0 ? -1. : 1.);
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <typename T>
__device__ T impl_fabs(T x) {
return fabs(x);
}
template <>
__device__ Eigen::half impl_fabs(Eigen::half x) {
return __float2half(fabs(__half2float(x)));
}
template <typename T>
__device__ T impl_sign(T x) {
return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
}
template <typename T, typename Tindex, bool has_l2_shrinkage>
__global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
const T* l1, const T* l2,
@ -183,87 +266,6 @@ struct ApplyGradientDescent<GPUDevice, T> {
}
};
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_complex.h"
// if any kernels involving complex sqrt/rsqrt are compiled with ROCm, build
// process completes without errors,but the resulting executable ends up
// unusable (throwing errors "no device code available for function" for
/// completely unrelated kernels.)
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
// not provide sqrt for complex.
// We have no choice but to implement sqrt and rsqrt by hand
template <typename T>
__device__ T impl_sqrt(T x) {
return sqrt(x);
}
template <typename T>
__device__ T impl_rsqrt(T x) {
return rsqrt(x);
}
template <>
__device__ Eigen::half impl_sqrt(Eigen::half x) {
return __float2half(sqrt(__half2float(x)));
}
template <>
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
return __float2half(rsqrt(__half2float(x)));
}
template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T mod_x = sqrt(re * re + im * im);
const T root2 = 0.7071067811865475;
// We pick the root with the same sign of the imaginary component as
// the input.
T root[2] = {T(sqrt(mod_x + re) * root2),
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
// hcc/clang is really weird with its support of complex in device code;
// for some reason it does not permit a 2-argument constructor
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <class T>
__device__ T rsqrt_helper(T x) {
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}
template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T r = rsqrt(re * re + im * im);
T ar2 = re * r * r;
const T root2 = 0.7071067811865475;
T root[2];
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
// due to subtraction of two close values. We have to get fancy
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
? rsqrt_helper(im * im * r * r)
: max(T(0.0), 1 + re * r))) *
root2;
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
? rsqrt_helper(im * im * r * r)
: max(T(0.0), 1 - re * r))) *
root2 * (im >= 0 ? -1. : 1.);
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <typename T>
__device__ T impl_fabs(T x) {
return fabs(x);
}
template <>
__device__ Eigen::half impl_fabs(Eigen::half x) {
return __float2half(fabs(__half2float(x)));
}
template <typename T>
__device__ T impl_sign(T x) {
return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
}
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdagradKernel(GpuLaunchConfig cfg,
T* var, T* accum,
@ -374,8 +376,6 @@ void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
using kernel_forward::wrap_kernel_call;
#endif
template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,