Merge pull request #35924 from ROCmSoftwarePlatform:google_upstream_training_ops

PiperOrigin-RevId: 306484877
Change-Id: Ibcf72e35c40a8a5d3570168d8e48aca77b86e36f
This commit is contained in:
TensorFlower Gardener 2020-04-14 12:00:17 -07:00
commit 3b29db9951
5 changed files with 216 additions and 52 deletions

View File

@ -747,9 +747,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@ -759,9 +757,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif
@ -924,9 +920,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@ -936,9 +930,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif
@ -1411,9 +1403,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@ -1423,9 +1413,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif
@ -1525,9 +1513,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@ -1537,9 +1523,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif
@ -4275,9 +4259,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@ -4287,9 +4269,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif

View File

@ -118,12 +118,182 @@ struct ApplyGradientDescent<GPUDevice, T> {
}
};
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_complex.h"
// if any kernels involving complex sqrt/rsqrt are compiled with ROCm, build
// process completes without errors,but the resulting executable ends up
// unusable (throwing errors "no device code available for function" for
/// completely unrelated kernels.)
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
// not provide sqrt for complex.
// We have no choice but to implement sqrt and rsqrt by hand
template <typename T>
__device__ T impl_sqrt(T x) {
return sqrt(x);
}
template <typename T>
__device__ T impl_rsqrt(T x) {
return rsqrt(x);
}
template <>
__device__ Eigen::half impl_sqrt(Eigen::half x) {
return __float2half(sqrt(__half2float(x)));
}
template <>
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
return __float2half(rsqrt(__half2float(x)));
}
template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T mod_x = sqrt(re * re + im * im);
const T root2 = 0.7071067811865475;
// We pick the root with the same sign of the imaginary component as
// the input.
T root[2] = {T(sqrt(mod_x + re) * root2),
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
// hcc/clang is really weird with its support of complex in device code;
// for some reason it does not permit a 2-argument constructor
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <class T>
__device__ T rsqrt_helper(T x) {
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}
template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T r = rsqrt(re * re + im * im);
T ar2 = re * r * r;
const T root2 = 0.7071067811865475;
T root[2];
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
// due to subtraction of two close values. We have to get fancy
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
? rsqrt_helper(im * im * r * r)
: 1 + re * r)) *
root2;
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
? rsqrt_helper(im * im * r * r)
: 1 - re * r)) *
root2 * (im >= 0 ? -1. : 1.);
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <typename T>
__global__ void ApplyAdagradKernel(GpuLaunchConfig cfg, T* var, T* accum,
const T* lr, const T* grad,
bool update_slots) {
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
if (update_slots) accum[i] += grad[i] * grad[i];
var[i] -= lr[0] * grad[i] * impl_rsqrt(accum[i]);
}
}
template <typename T>
__global__ void ApplyAdagradV2Kernel(GpuLaunchConfig cfg, T* var, T* accum,
const T* lr, const T* epsilon,
const T* grad, bool update_slots) {
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
if (update_slots) accum[i] += grad[i] * grad[i];
T update = grad[i] / (impl_sqrt(accum[i]) + epsilon[0]);
var[i] -= lr[0] * update;
}
}
template <typename T>
__global__ void ApplyAdadeltaKernel(GpuLaunchConfig cfg, T* var, T* accum,
T* accum_update, const T* plr,
const T* prho, const T* peps,
const T* grad) {
T rho = prho[0];
T eps = peps[0];
T lr = plr[0];
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
T update =
impl_sqrt(accum_update[i] + eps) * grad[i] * impl_rsqrt(accum[i] + eps);
var[i] -= update * lr;
accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
}
}
template <typename T>
__global__ void ApplyRMSPropKernel(GpuLaunchConfig cfg, T* var, T* ms, T* mom,
const T* plr, const T* prho,
const T* pmomentum, const T* peps,
const T* grad) {
T rho = prho[0];
T eps = peps[0];
T lr = plr[0];
T momentum = pmomentum[0];
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(eps + ms[i]);
var[i] -= mom[i];
}
}
template <typename T>
__global__ void ApplyCenteredRMSPropKernel(GpuLaunchConfig cfg, T* var, T* mg,
T* ms, T* mom, const T* plr,
const T* prho, const T* pmomentum,
const T* peps, const T* grad) {
T rho = prho[0];
T eps = peps[0];
T lr = plr[0];
T momentum = pmomentum[0];
T one_minus_rho = T(1.0) - rho;
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
mg[i] += one_minus_rho * (grad[i] - mg[i]);
T denom = (ms[i] - mg[i] * mg[i]) + eps;
mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(denom);
var[i] -= mom[i];
}
}
namespace kernel_forward {
bool to_pointers(bool x) { return x; }
template <class T>
typename T::PointerType to_pointers(T& x) {
return x.data();
}
template <class T>
typename T::ConstPointerType to_pointers(const T& x) {
return x.data();
}
template <typename T, typename... CallerArgs, typename... KernelArgs>
void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
CallerArgs... args) {
int32 data_dim = var.dimension(0);
auto config = GetGpuLaunchConfig(data_dim, d);
TF_CHECK_OK(GpuLaunchKernel(func, config.block_count, config.thread_per_block,
0, d.stream(), config, var.data(),
to_pointers(args)...));
}
}; // namespace kernel_forward
using kernel_forward::wrap_kernel_call;
#endif
template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
update_slots);
#else
if (update_slots) {
accum.device(d) += grad.square();
}
@ -131,6 +301,7 @@ struct ApplyAdagrad<GPUDevice, T> {
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
#endif
}
};
@ -141,6 +312,10 @@ struct ApplyAdagradV2<GPUDevice, T> {
typename TTypes<T>::ConstScalar lr,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
update_slots);
#else
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
@ -150,9 +325,9 @@ struct ApplyAdagradV2<GPUDevice, T> {
const auto update =
grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
#endif
}
};
template <typename T>
struct ApplyAdadelta<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@ -162,6 +337,10 @@ struct ApplyAdadelta<GPUDevice, T> {
typename TTypes<T>::ConstScalar rho,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
rho, epsilon, grad);
#else
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
@ -177,6 +356,7 @@ struct ApplyAdadelta<GPUDevice, T> {
accum_update * rho.reshape(single).broadcast(bcast) +
update.square() *
(grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
#endif
}
};
@ -489,6 +669,10 @@ struct ApplyRMSProp<GPUDevice, T> {
typename TTypes<T>::ConstScalar momentum,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
epsilon, grad);
#else
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
@ -501,6 +685,7 @@ struct ApplyRMSProp<GPUDevice, T> {
lr.reshape(single).broadcast(bcast) * grad /
((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
var.device(d) -= mom;
#endif
}
};
@ -514,6 +699,10 @@ struct ApplyCenteredRMSProp<GPUDevice, T> {
typename TTypes<T>::ConstScalar momentum,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
rho, momentum, epsilon, grad);
#else
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
@ -526,6 +715,7 @@ struct ApplyCenteredRMSProp<GPUDevice, T> {
mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
var.device(d) -= mom;
#endif
}
};
@ -599,9 +789,8 @@ struct ApplyPowerSign<GPUDevice, T> {
template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
template struct functor::ApplyGradientDescent<GPUDevice, float>;
template struct functor::ApplyGradientDescent<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support
// complex sqrt
template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
#endif
@ -609,9 +798,8 @@ template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagrad<GPUDevice, float>;
template struct functor::ApplyAdagrad<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support
// complex sqrt
template struct functor::ApplyAdagrad<GPUDevice, complex64>;
template struct functor::ApplyAdagrad<GPUDevice, complex128>;
#endif
@ -619,9 +807,8 @@ template struct functor::ApplyAdagrad<GPUDevice, complex128>;
template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagradV2<GPUDevice, float>;
template struct functor::ApplyAdagradV2<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support
// complex sqrt
template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
#endif
@ -629,9 +816,8 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
template struct functor::ApplyAdadelta<GPUDevice, float>;
template struct functor::ApplyAdadelta<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support
// complex sqrt
template struct functor::ApplyAdadelta<GPUDevice, complex64>;
template struct functor::ApplyAdadelta<GPUDevice, complex128>;
#endif
@ -710,9 +896,8 @@ template struct functor::ApplyAdaMax<GPUDevice, double>;
template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyRMSProp<GPUDevice, float>;
template struct functor::ApplyRMSProp<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support
// complex sqrt
template struct functor::ApplyRMSProp<GPUDevice, complex64>;
template struct functor::ApplyRMSProp<GPUDevice, complex128>;
#endif
@ -720,9 +905,8 @@ template struct functor::ApplyRMSProp<GPUDevice, complex128>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
!defined(TENSORFLOW_USE_ROCM) // TODO(b/143684500): Eigen to support
// complex sqrt
#ifndef TENSORFLOW_USE_NVCC // TODO(b/143684500): Eigen to support
// complex sqrt
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;
#endif

View File

@ -36,7 +36,7 @@ from tensorflow.python.platform import test
_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
if not test_util.IsBuiltWithNvcc():
_DATA_TYPES += [dtypes.complex64, dtypes.complex128]

View File

@ -39,7 +39,7 @@ from tensorflow.python.platform import test
_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
if not test_util.IsBuiltWithNvcc():
_DATA_TYPES += [dtypes.complex64, dtypes.complex128]

View File

@ -41,7 +41,7 @@ from tensorflow.python.platform import test
_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
if not test_util.IsBuiltWithNvcc():
_DATA_TYPES += [dtypes.complex64, dtypes.complex128]
_TEST_PARAM_VALUES = [