Merge pull request #35924 from ROCmSoftwarePlatform:google_upstream_training_ops

PiperOrigin-RevId: 306484877
Change-Id: Ibcf72e35c40a8a5d3570168d8e48aca77b86e36f
commit 3b29db9951
@@ -747,9 +747,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@@ -759,9 +757,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif
@@ -924,9 +920,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@@ -936,9 +930,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif
@@ -1411,9 +1403,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@@ -1423,9 +1413,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif
@@ -1525,9 +1513,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@@ -1537,9 +1523,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif
@@ -4275,9 +4259,7 @@ namespace functor {
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
@@ -4287,9 +4269,7 @@ DECLARE_GPU_SPEC(complex128);
REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support complex sqrt
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif

@@ -118,12 +118,182 @@ struct ApplyGradientDescent<GPUDevice, T> {
  }
};

#if TENSORFLOW_USE_ROCM

#include "rocm/include/hip/hip_complex.h"

// If any kernels involving complex sqrt/rsqrt are compiled with ROCm, the build
// process completes without errors, but the resulting executable ends up
// unusable (throwing errors "no device code available for function" for
// completely unrelated kernels).
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
// not provide sqrt for complex.
// We have no choice but to implement sqrt and rsqrt by hand.
template <typename T>
__device__ T impl_sqrt(T x) {
  return sqrt(x);
}
template <typename T>
__device__ T impl_rsqrt(T x) {
  return rsqrt(x);
}
template <>
__device__ Eigen::half impl_sqrt(Eigen::half x) {
  return __float2half(sqrt(__half2float(x)));
}
template <>
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
  return __float2half(rsqrt(__half2float(x)));
}

template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
  T re = x.real(), im = x.imag();
  T mod_x = sqrt(re * re + im * im);
  const T root2 = 0.7071067811865475;
  // We pick the root with the same sign of the imaginary component as
  // the input.
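  // Principal branch: sqrt(re + i*im) = sqrt((mod_x + re) / 2)
  //                                     + i * sign(im) * sqrt((mod_x - re) / 2);
  // root2 == 1/sqrt(2) folds the divisions by 2 into the factors below.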
  T root[2] = {T(sqrt(mod_x + re) * root2),
               T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
  // hcc/clang is really weird with its support of complex in device code;
  // for some reason it does not permit a 2-argument constructor
  return *(reinterpret_cast<std::complex<T>*>(&root));
}

template <class T>
__device__ T rsqrt_helper(T x) {
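  // First three Taylor terms of 1 - sqrt(1 - x) = x/2 + x^2/8 + x^3/16 + ...;
  // impl_rsqrt below uses this when 1 +/- re*r would suffer catastrophic
  // cancellation (i.e. when |re*r| is close to 1).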
  return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}

template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
  T re = x.real(), im = x.imag();
  T r = rsqrt(re * re + im * im);
  T ar2 = re * r * r;
  const T root2 = 0.7071067811865475;
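  // Computes 1/sqrt(x) as conj(sqrt(x)) / |x|: with r = 1/|x|,
  // sqrt(r * (1 +/- re*r)) * root2 equals sqrt((|x| +/- re) / 2) / |x|.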
  T root[2];
  // With float, calculating 1+re*r and 1-re*r may result in excessive errors
  // due to subtraction of two close values. We have to get fancy
  root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
                          ? rsqrt_helper(im * im * r * r)
                          : 1 + re * r)) *
            root2;
  root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
                          ? rsqrt_helper(im * im * r * r)
                          : 1 - re * r)) *
            root2 * (im >= 0 ? -1. : 1.);
  return *(reinterpret_cast<std::complex<T>*>(&root));
}

template <typename T>
__global__ void ApplyAdagradKernel(GpuLaunchConfig cfg, T* var, T* accum,
                                   const T* lr, const T* grad,
                                   bool update_slots) {
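  // Element-wise Adagrad update: accum += grad^2 (when update_slots is set),
  // then var -= lr * grad / sqrt(accum); mirrors the Eigen path further down.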
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    var[i] -= lr[0] * grad[i] * impl_rsqrt(accum[i]);
  }
}

template <typename T>
__global__ void ApplyAdagradV2Kernel(GpuLaunchConfig cfg, T* var, T* accum,
                                     const T* lr, const T* epsilon,
                                     const T* grad, bool update_slots) {
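  // AdagradV2 update: accum += grad^2 (when update_slots is set), then
  // var -= lr * grad / (sqrt(accum) + epsilon).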
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    T update = grad[i] / (impl_sqrt(accum[i]) + epsilon[0]);
    var[i] -= lr[0] * update;
  }
}

template <typename T>
__global__ void ApplyAdadeltaKernel(GpuLaunchConfig cfg, T* var, T* accum,
                                    T* accum_update, const T* plr,
                                    const T* prho, const T* peps,
                                    const T* grad) {
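  // Adadelta update: accum = rho * accum + (1 - rho) * grad^2,
  // update = sqrt(accum_update + eps) / sqrt(accum + eps) * grad,
  // var -= lr * update, accum_update = rho * accum_update + (1 - rho) * update^2.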
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
    T update =
        impl_sqrt(accum_update[i] + eps) * grad[i] * impl_rsqrt(accum[i] + eps);
    var[i] -= update * lr;
    accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
  }
}

template <typename T>
__global__ void ApplyRMSPropKernel(GpuLaunchConfig cfg, T* var, T* ms, T* mom,
                                   const T* plr, const T* prho,
                                   const T* pmomentum, const T* peps,
                                   const T* grad) {
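  // RMSProp update: ms += (1 - rho) * (grad^2 - ms),
  // mom = momentum * mom + lr * grad / sqrt(eps + ms), var -= mom.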
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  T momentum = pmomentum[0];
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(eps + ms[i]);
    var[i] -= mom[i];
  }
}

template <typename T>
__global__ void ApplyCenteredRMSPropKernel(GpuLaunchConfig cfg, T* var, T* mg,
                                           T* ms, T* mom, const T* plr,
                                           const T* prho, const T* pmomentum,
                                           const T* peps, const T* grad) {
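  // Centered RMSProp: additionally tracks the gradient mean mg and uses the
  // centered second moment ms - mg^2 + eps in the denominator.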
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  T momentum = pmomentum[0];
  T one_minus_rho = T(1.0) - rho;
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
    mg[i] += one_minus_rho * (grad[i] - mg[i]);
    T denom = (ms[i] - mg[i] * mg[i]) + eps;
    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(denom);
    var[i] -= mom[i];
  }
}

namespace kernel_forward {
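// Adapters between the Eigen tensor arguments taken by the functors and the
// raw device pointers expected by the kernels above: to_pointers() maps each
// tensor to its .data() pointer (passing bool flags through unchanged), and
// wrap_kernel_call() builds a 1-D launch config over var and launches func.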
bool to_pointers(bool x) { return x; }
template <class T>
typename T::PointerType to_pointers(T& x) {
  return x.data();
}
template <class T>
typename T::ConstPointerType to_pointers(const T& x) {
  return x.data();
}

template <typename T, typename... CallerArgs, typename... KernelArgs>
void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
                      CallerArgs... args) {
  int32 data_dim = var.dimension(0);
  auto config = GetGpuLaunchConfig(data_dim, d);
  TF_CHECK_OK(GpuLaunchKernel(func, config.block_count, config.thread_per_block,
                              0, d.stream(), config, var.data(),
                              to_pointers(args)...));
}
};  // namespace kernel_forward

using kernel_forward::wrap_kernel_call;

#endif

template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
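    // ROCm builds dispatch to the hand-written kernel above; the Eigen
    // expression below would pull complex sqrt/rsqrt into device code, which
    // (per the comment at the top of this block) breaks ROCm executables.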
    wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
                     update_slots);
#else
    if (update_slots) {
      accum.device(d) += grad.square();
    }
@@ -131,6 +301,7 @@ struct ApplyAdagrad<GPUDevice, T> {
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
#endif
  }
};

@@ -141,6 +312,10 @@ struct ApplyAdagradV2<GPUDevice, T> {
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
                     update_slots);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
@@ -150,9 +325,9 @@ struct ApplyAdagradV2<GPUDevice, T> {
    const auto update =
        grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
    var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
#endif
  }
};

template <typename T>
struct ApplyAdadelta<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -162,6 +337,10 @@ struct ApplyAdadelta<GPUDevice, T> {
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
                     rho, epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
@@ -177,6 +356,7 @@ struct ApplyAdadelta<GPUDevice, T> {
        accum_update * rho.reshape(single).broadcast(bcast) +
        update.square() *
            (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
#endif
  }
};

@@ -489,6 +669,10 @@ struct ApplyRMSProp<GPUDevice, T> {
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
                     epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
@@ -501,6 +685,7 @@ struct ApplyRMSProp<GPUDevice, T> {
        lr.reshape(single).broadcast(bcast) * grad /
            ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
    var.device(d) -= mom;
#endif
  }
};

@@ -514,6 +699,10 @@ struct ApplyCenteredRMSProp<GPUDevice, T> {
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
                     rho, momentum, epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
@@ -526,6 +715,7 @@ struct ApplyCenteredRMSProp<GPUDevice, T> {
    mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
                    lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
    var.device(d) -= mom;
#endif
  }
};

@@ -599,9 +789,8 @@ struct ApplyPowerSign<GPUDevice, T> {
template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
template struct functor::ApplyGradientDescent<GPUDevice, float>;
template struct functor::ApplyGradientDescent<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
#endif
@@ -609,9 +798,8 @@ template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagrad<GPUDevice, float>;
template struct functor::ApplyAdagrad<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdagrad<GPUDevice, complex64>;
template struct functor::ApplyAdagrad<GPUDevice, complex128>;
#endif
@@ -619,9 +807,8 @@ template struct functor::ApplyAdagrad<GPUDevice, complex128>;
template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagradV2<GPUDevice, float>;
template struct functor::ApplyAdagradV2<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
#endif
@@ -629,9 +816,8 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
template struct functor::ApplyAdadelta<GPUDevice, float>;
template struct functor::ApplyAdadelta<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdadelta<GPUDevice, complex64>;
template struct functor::ApplyAdadelta<GPUDevice, complex128>;
#endif
@@ -710,9 +896,8 @@ template struct functor::ApplyAdaMax<GPUDevice, double>;
template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyRMSProp<GPUDevice, float>;
template struct functor::ApplyRMSProp<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyRMSProp<GPUDevice, complex64>;
template struct functor::ApplyRMSProp<GPUDevice, complex128>;
#endif
@@ -720,9 +905,8 @@ template struct functor::ApplyRMSProp<GPUDevice, complex128>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;
#endif

@@ -36,7 +36,7 @@ from tensorflow.python.platform import test

_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
if not test_util.IsBuiltWithNvcc():
  _DATA_TYPES += [dtypes.complex64, dtypes.complex128]

@@ -39,7 +39,7 @@ from tensorflow.python.platform import test

_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
if not test_util.IsBuiltWithNvcc():
  _DATA_TYPES += [dtypes.complex64, dtypes.complex128]

@@ -41,7 +41,7 @@ from tensorflow.python.platform import test

_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
if not test_util.IsBuiltWithNvcc():
  _DATA_TYPES += [dtypes.complex64, dtypes.complex128]

_TEST_PARAM_VALUES = [