ROCm-optimized training apply ops
parent 057cf24986
commit 3693256ba0
@@ -670,9 +670,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -684,9 +683,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -853,9 +851,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -867,9 +864,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -1348,9 +1344,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -1362,9 +1357,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -1468,9 +1462,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -1482,9 +1475,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);
@@ -4183,9 +4175,8 @@ namespace functor {
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
@@ -4197,9 +4188,8 @@ DECLARE_GPU_SPEC(complex128);
 REGISTER_KERNELS(GPU, Eigen::half);
 REGISTER_KERNELS(GPU, float);
 REGISTER_KERNELS(GPU, double);
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined( \
+    TENSORFLOW_USE_NVCC)  // TODO(b/143684500): Eigen to support complex sqrt
 #ifndef PLATFORM_WINDOWS
 REGISTER_KERNELS(GPU, complex64);
 REGISTER_KERNELS(GPU, complex128);

@@ -118,12 +118,184 @@ struct ApplyGradientDescent<GPUDevice, T> {
   }
 };
 
+#if TENSORFLOW_USE_ROCM
+
+#include "rocm/include/hip/hip_complex.h"
+
+// If any kernels involving complex sqrt/rsqrt are compiled with ROCm, the
+// build process completes without errors, but the resulting executable ends
+// up unusable (throwing "no device code available for function" errors for
+// completely unrelated kernels).
+// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
+// not provide sqrt for complex.
+// We have no choice but to implement sqrt and rsqrt by hand.
+template <typename T>
+__device__ T impl_sqrt(T x) {
+  return sqrt(x);
+}
+template <typename T>
+__device__ T impl_rsqrt(T x) {
+  return rsqrt(x);
+}
+template <>
+__device__ Eigen::half impl_sqrt(Eigen::half x) {
+  return __float2half(sqrt(__half2float(x)));
+}
+template <>
+__device__ Eigen::half impl_rsqrt(Eigen::half x) {
+  return __float2half(rsqrt(__half2float(x)));
+}
+
+template <class T>
+__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
+  T re = x.real(), im = x.imag();
+  T mod_x = sqrt(re * re + im * im);
+  const T root2 = 0.7071067811865475;
+  // We pick the root with the same sign of the imaginary component as the input.
+  T root[2] = {T(sqrt(mod_x + re) * root2),
+               T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
+  // hcc/clang is really weird with its support of complex in device code;
+  // for some reason it does not permit a 2-argument constructor.
+  return *(reinterpret_cast<std::complex<T>*>(&root));
+}
+
+template <class T>
+__device__ T rsqrt_helper(T x) {
+  return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
+}
+
+template <class T>
+__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
+  T re = x.real(), im = x.imag();
+  T r = rsqrt(re * re + im * im);
+  T ar2 = re * r * r;
+  const T root2 = 0.7071067811865475;
+  T root[2];
+  // With float, calculating 1+re*r and 1-re*r may result in excessive errors
+  // due to subtraction of two close values. We have to get fancy.
+  root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
+                          ? rsqrt_helper(im * im * r * r)
+                          : 1 + re * r)) *
+            root2;
+  root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
+                          ? rsqrt_helper(im * im * r * r)
+                          : 1 - re * r)) *
+            root2 * (im >= 0 ? -1. : 1.);
+  return *(reinterpret_cast<std::complex<T>*>(&root));
+}
+
+template <typename T>
+__global__ void ApplyAdagradKernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                   const T* lr, const T* grad,
+                                   bool update_slots) {
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    if (update_slots) accum[i] += grad[i] * grad[i];
+    var[i] -= lr[0] * grad[i] * impl_rsqrt(accum[i]);
+  }
+}
+
+template <typename T>
+__global__ void ApplyAdagradV2Kernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                     const T* lr, const T* epsilon,
+                                     const T* grad, bool update_slots) {
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    if (update_slots) accum[i] += grad[i] * grad[i];
+    T update = grad[i] / (impl_sqrt(accum[i]) + epsilon[0]);
+    var[i] -= lr[0] * update;
+  }
+}
+
+template <typename T>
+__global__ void ApplyAdadeltaKernel(GpuLaunchConfig cfg, T* var, T* accum,
+                                    T* accum_update, const T* plr,
+                                    const T* prho, const T* peps,
+                                    const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
+    T update =
+        impl_sqrt(accum_update[i] + eps) * grad[i] * impl_rsqrt(accum[i] + eps);
+    var[i] -= update * lr;
+    accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
+  }
+}
+
+template <typename T>
+__global__ void ApplyRMSPropKernel(GpuLaunchConfig cfg, T* var, T* ms, T* mom,
+                                   const T* plr, const T* prho,
+                                   const T* pmomentum, const T* peps,
+                                   const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  T momentum = pmomentum[0];
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
+    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(eps + ms[i]);
+    var[i] -= mom[i];
+  }
+}
+
+template <typename T>
+__global__ void ApplyCenteredRMSPropKernel(GpuLaunchConfig cfg, T* var, T* mg,
+                                           T* ms, T* mom, const T* plr,
+                                           const T* prho, const T* pmomentum,
+                                           const T* peps, const T* grad) {
+  T rho = prho[0];
+  T eps = peps[0];
+  T lr = plr[0];
+  T momentum = pmomentum[0];
+  T one_minus_rho = T(1.0) - rho;
+  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
+    ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
+    mg[i] += one_minus_rho * (grad[i] - mg[i]);
+    T denom = (ms[i] - mg[i] * mg[i]) + eps;
+    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(denom);
+    var[i] -= mom[i];
+  }
+}
+#endif
+
+#if TENSORFLOW_USE_ROCM
+
+namespace kernel_forward {
+bool to_pointers(bool x) { return x; }
+template <class T>
+typename T::PointerType to_pointers(T& x) {
+  return x.data();
+}
+template <class T>
+typename T::ConstPointerType to_pointers(const T& x) {
+  return x.data();
+}
+
+template <typename T, typename... CallerArgs, typename... KernelArgs>
+void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
+                      CallerArgs... args) {
+  int32 data_dim = var.dimension(0);
+  auto config = GetGpuLaunchConfig(data_dim, d);
+  TF_CHECK_OK(GpuLaunchKernel(func, config.block_count, config.thread_per_block,
+                              0, d.stream(), config, var.data(),
+                              to_pointers(args)...));
+}
+};  // namespace kernel_forward
+
+using kernel_forward::wrap_kernel_call;
+
+#endif
+
 template <typename T>
 struct ApplyAdagrad<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                   typename TTypes<T>::Flat accum,
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstFlat grad, bool update_slots) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
+                     update_slots);
+#else
     if (update_slots) {
      accum.device(d) += grad.square();
    }
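
Note: impl_sqrt above computes the principal branch of the complex square root. For z = a + bi with m = |z|, sqrt(z) = sqrt((m + a)/2) + i*sign(b)*sqrt((m - a)/2); impl_rsqrt folds the division by |z| into the same form via r = 1/|z| (flipping the sign of the imaginary part, since 1/sqrt(z) = conj(sqrt(z))/|z|). The rsqrt_helper series handles the near-cancellation when re*r is close to +/-1: because (re*r)^2 + (im*r)^2 = 1 exactly, 1 +/- re*r equals 1 - sqrt(1 - (im*r)^2) for re*r of the matching sign, and rsqrt_helper(x) = x/2 + x^2/8 + x^3/16 approximates 1 - sqrt(1 - x) without subtracting nearly equal values. A small host-side sketch (not part of the commit) that checks the sqrt formula against std::sqrt:

#include <cmath>
#include <complex>
#include <cstdio>

std::complex<double> sqrt_by_hand(std::complex<double> z) {
  const double root2 = 0.7071067811865475;    // 1/sqrt(2), as in the kernel
  double m = std::hypot(z.real(), z.imag());  // |z|
  double re = std::sqrt(m + z.real()) * root2;
  double im = std::sqrt(m - z.real()) * root2 * (z.imag() >= 0 ? 1.0 : -1.0);
  return {re, im};
}

int main() {
  const std::complex<double> z(-3.0, 4.0);
  const auto a = sqrt_by_hand(z);
  const auto b = std::sqrt(z);
  std::printf("by hand: %f%+fi   std::sqrt: %f%+fi\n", a.real(), a.imag(),
              b.real(), b.imag());  // both print 1.000000+2.000000i
}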
@@ -131,6 +303,7 @@ struct ApplyAdagrad<GPUDevice, T> {
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
     var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
+#endif
   }
 };
 
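Note: the kernel_forward helpers introduced above let each functor keep a single call site. to_pointers maps TensorFlow's flat tensor arguments to raw device pointers (and passes bools through), and wrap_kernel_call forwards the converted pack to GpuLaunchKernel. A minimal host-side analogue of the forwarding pattern (a sketch with simplified stand-in types and no GPU launch, not part of the commit):

#include <cstdio>
#include <vector>

struct Flat {
  std::vector<float> buf;
  float* data() { return buf.data(); }
};

// Overloads mirror kernel_forward::to_pointers: scalars pass through,
// tensor-like arguments decay to raw pointers.
bool to_pointers(bool x) { return x; }
float* to_pointers(Flat& x) { return x.data(); }

template <typename... CallerArgs, typename... KernelArgs>
void wrap_call(void (*func)(KernelArgs...), CallerArgs&&... args) {
  func(to_pointers(args)...);  // overload resolution converts each argument
}

void kernel(float* var, bool flag) { std::printf("%f %d\n", var[0], flag); }

int main() {
  Flat v{{1.5f}};
  wrap_call(kernel, v, true);  // prints "1.500000 1"
}

Overload resolution on to_pointers is what lets one variadic wrapper serve every optimizer kernel above.
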
@@ -141,6 +314,10 @@ struct ApplyAdagradV2<GPUDevice, T> {
                   typename TTypes<T>::ConstScalar lr,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad, bool update_slots) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
+                     update_slots);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -150,9 +327,9 @@ struct ApplyAdagradV2<GPUDevice, T> {
     const auto update =
         grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
     var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
+#endif
   }
 };
 
 template <typename T>
 struct ApplyAdadelta<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -162,6 +339,10 @@ struct ApplyAdadelta<GPUDevice, T> {
                   typename TTypes<T>::ConstScalar rho,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
+                     rho, epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -177,6 +358,7 @@ struct ApplyAdadelta<GPUDevice, T> {
         accum_update * rho.reshape(single).broadcast(bcast) +
         update.square() *
             (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
+#endif
   }
 };
 
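Note: a scalar reference for one Adadelta step, mirroring ApplyAdadeltaKernel above (a host-side sketch, not part of the commit; plain doubles stand in for the device tensors):

#include <cmath>
#include <cstdio>

void adadelta_step(double& var, double& accum, double& accum_update,
                   double lr, double rho, double eps, double grad) {
  // Decayed accumulation of squared gradients.
  accum = accum * rho + grad * grad * (1.0 - rho);
  // RMS-scaled update, matching impl_sqrt(...) * grad * impl_rsqrt(...).
  double update =
      std::sqrt(accum_update + eps) * grad / std::sqrt(accum + eps);
  var -= update * lr;
  // Decayed accumulation of squared updates.
  accum_update = accum_update * rho + update * update * (1.0 - rho);
}

int main() {
  double var = 1.0, accum = 0.0, accum_update = 0.0;
  adadelta_step(var, accum, accum_update, /*lr=*/1.0, /*rho=*/0.95,
                /*eps=*/1e-8, /*grad=*/0.5);
  std::printf("var=%f accum=%f accum_update=%f\n", var, accum, accum_update);
}
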
@@ -414,6 +596,10 @@ struct ApplyRMSProp<GPUDevice, T> {
                   typename TTypes<T>::ConstScalar momentum,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
+                     epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -426,6 +612,7 @@ struct ApplyRMSProp<GPUDevice, T> {
             lr.reshape(single).broadcast(bcast) * grad /
                 ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
     var.device(d) -= mom;
+#endif
   }
 };
 
@@ -439,6 +626,10 @@ struct ApplyCenteredRMSProp<GPUDevice, T> {
                   typename TTypes<T>::ConstScalar momentum,
                   typename TTypes<T>::ConstScalar epsilon,
                   typename TTypes<T>::ConstFlat grad) {
+#if TENSORFLOW_USE_ROCM
+    wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
+                     rho, momentum, epsilon, grad);
+#else
     Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
@@ -451,6 +642,7 @@ struct ApplyCenteredRMSProp<GPUDevice, T> {
     mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
                     lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
     var.device(d) -= mom;
+#endif
   }
 };
 
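Note: likewise for centered RMSProp, where the kernel subtracts mg*mg to center the second-moment estimate before the rsqrt. A scalar host-side sketch of one step (again not part of the commit):

#include <cmath>
#include <cstdio>

void centered_rmsprop_step(double& var, double& mg, double& ms, double& mom,
                           double lr, double rho, double momentum, double eps,
                           double grad) {
  double one_minus_rho = 1.0 - rho;
  ms += one_minus_rho * (grad * grad - ms);  // second moment
  mg += one_minus_rho * (grad - mg);         // first moment
  double denom = (ms - mg * mg) + eps;       // centered variance estimate
  mom = mom * momentum + lr * grad / std::sqrt(denom);
  var -= mom;
}

int main() {
  double var = 1.0, mg = 0.0, ms = 0.0, mom = 0.0;
  centered_rmsprop_step(var, mg, ms, mom, /*lr=*/0.01, /*rho=*/0.9,
                        /*momentum=*/0.9, /*eps=*/1e-8, /*grad=*/0.5);
  std::printf("var=%f\n", var);
}
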
@@ -524,9 +716,7 @@ struct ApplyPowerSign<GPUDevice, T> {
 template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
 template struct functor::ApplyGradientDescent<GPUDevice, float>;
 template struct functor::ApplyGradientDescent<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
 template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
@@ -536,9 +726,7 @@ template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
 template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdagrad<GPUDevice, float>;
 template struct functor::ApplyAdagrad<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyAdagrad<GPUDevice, complex64>;
 template struct functor::ApplyAdagrad<GPUDevice, complex128>;
@@ -548,9 +736,7 @@ template struct functor::ApplyAdagrad<GPUDevice, complex128>;
 template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdagradV2<GPUDevice, float>;
 template struct functor::ApplyAdagradV2<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
 template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
@@ -560,9 +746,7 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
 template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdadelta<GPUDevice, float>;
 template struct functor::ApplyAdadelta<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyAdadelta<GPUDevice, complex64>;
 template struct functor::ApplyAdadelta<GPUDevice, complex128>;
@@ -580,9 +764,7 @@ template struct functor::ApplyFtrlV2<GPUDevice, double>;
 template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
 template struct functor::ApplyMomentum<GPUDevice, float>;
 template struct functor::ApplyMomentum<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyMomentum<GPUDevice, complex64>;
 template struct functor::ApplyMomentum<GPUDevice, complex128>;
@@ -592,9 +774,7 @@ template struct functor::ApplyMomentum<GPUDevice, complex128>;
 template struct functor::ApplyKerasMomentum<GPUDevice, Eigen::half>;
 template struct functor::ApplyKerasMomentum<GPUDevice, float>;
 template struct functor::ApplyKerasMomentum<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyKerasMomentum<GPUDevice, complex64>;
 template struct functor::ApplyKerasMomentum<GPUDevice, complex128>;
@@ -609,9 +789,7 @@ template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int32>;
 template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int64>;
 template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int32>;
 template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int64>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
 #ifndef PLATFORM_WINDOWS
 template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int32>;
 template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int64>;
@@ -623,9 +801,7 @@ template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int64>;
 template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdam<GPUDevice, float>;
 template struct functor::ApplyAdam<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC) && !defined(TENSORFLOW_USE_ROCM)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyAdam<GPUDevice, complex64>;
 template struct functor::ApplyAdam<GPUDevice, complex128>;
@@ -643,9 +819,7 @@ template struct functor::ApplyAdaMax<GPUDevice, double>;
 template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
 template struct functor::ApplyRMSProp<GPUDevice, float>;
 template struct functor::ApplyRMSProp<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyRMSProp<GPUDevice, complex64>;
 template struct functor::ApplyRMSProp<GPUDevice, complex128>;
@@ -655,9 +829,7 @@ template struct functor::ApplyRMSProp<GPUDevice, complex128>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
-#if !defined(TENSORFLOW_USE_NVCC) && \
-    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
-                                   // complex sqrt
+#if !defined(TENSORFLOW_USE_NVCC)
 #ifndef PLATFORM_WINDOWS
 template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
 template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;

@@ -35,8 +35,7 @@ from tensorflow.python.platform import test
 
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and
-    not test.is_built_with_rocm()):
+if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows"):
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 
@@ -38,8 +38,7 @@ from tensorflow.python.platform import test
 
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and
-    not test.is_built_with_rocm()):
+if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows"):
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 
@@ -41,8 +41,7 @@ from tensorflow.python.platform import test
 
 _DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
 # TODO(b/143684500): Eigen to support complex sqrt
-if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows" and
-    not test.is_built_with_rocm()):
+if (not test_util.IsBuiltWithNvcc() and platform.system() != "Windows"):
   _DATA_TYPES += [dtypes.complex64, dtypes.complex128]
 
 _TEST_PARAM_VALUES = [