Merge pull request #45818 from benbarsdell:gpu-SparseApplyProximalAdagrad-gpu-proper
PiperOrigin-RevId: 351148337 Change-Id: Ifa95dfde1376137c8de1c17b9763fddaaff4de5d
This commit is contained in:
commit
9a7c57d1d8
@ -1942,6 +1942,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
|
||||
TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
|
||||
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T, Tindex) \
|
||||
template <> \
|
||||
Status \
|
||||
SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/false>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T>::Matrix var, \
|
||||
typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
|
||||
typename TTypes<T>::ConstScalar epsilon, \
|
||||
typename TTypes<T>::ConstMatrix grad, \
|
||||
typename TTypes<Tindex>::ConstVec indices, int64 inner_dim, \
|
||||
bool update_slots); \
|
||||
extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex, \
|
||||
/*has_epsilon=*/false>;
|
||||
DECLARE_GPU_SPEC(Eigen::half, int32);
|
||||
DECLARE_GPU_SPEC(Eigen::half, int64);
|
||||
DECLARE_GPU_SPEC(float, int32);
|
||||
DECLARE_GPU_SPEC(float, int64);
|
||||
DECLARE_GPU_SPEC(double, int32);
|
||||
DECLARE_GPU_SPEC(double, int64);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
||||
REGISTER_KERNELS(GPU, Eigen::half, int32);
|
||||
REGISTER_KERNELS(GPU, Eigen::half, int64);
|
||||
REGISTER_KERNELS(GPU, float, int32);
|
||||
REGISTER_KERNELS(GPU, float, int64);
|
||||
REGISTER_KERNELS(GPU, double, int32);
|
||||
REGISTER_KERNELS(GPU, double, int64);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
template <typename Device, typename T, typename Tindex>
|
||||
@ -2043,6 +2075,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
|
||||
TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
|
||||
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T, Tindex) \
|
||||
template <> \
|
||||
Status \
|
||||
SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/true>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T>::Matrix var, \
|
||||
typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
|
||||
typename TTypes<T>::ConstScalar epsilon, \
|
||||
typename TTypes<T>::ConstMatrix grad, \
|
||||
typename TTypes<Tindex>::ConstVec indices, int64 inner_dim, \
|
||||
bool update_slots); \
|
||||
extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex, \
|
||||
/*has_epsilon=*/true>;
|
||||
DECLARE_GPU_SPEC(Eigen::half, int32);
|
||||
DECLARE_GPU_SPEC(Eigen::half, int64);
|
||||
DECLARE_GPU_SPEC(float, int32);
|
||||
DECLARE_GPU_SPEC(float, int64);
|
||||
DECLARE_GPU_SPEC(double, int32);
|
||||
DECLARE_GPU_SPEC(double, int64);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
||||
REGISTER_KERNELS(GPU, Eigen::half, int32);
|
||||
REGISTER_KERNELS(GPU, Eigen::half, int64);
|
||||
REGISTER_KERNELS(GPU, float, int32);
|
||||
REGISTER_KERNELS(GPU, float, int64);
|
||||
REGISTER_KERNELS(GPU, double, int32);
|
||||
REGISTER_KERNELS(GPU, double, int64);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
template <typename Device, typename T, typename Tindex>
|
||||
@ -2158,6 +2222,34 @@ REGISTER_KERNELS(CPU, float, int64);
|
||||
REGISTER_KERNELS(CPU, double, int32);
|
||||
REGISTER_KERNELS(CPU, double, int64);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T, Tindex) \
|
||||
template <> \
|
||||
Status SparseApplyProximalAdagrad<GPUDevice, T, Tindex>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T>::Matrix var, \
|
||||
typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
|
||||
typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
|
||||
typename TTypes<T>::ConstMatrix grad, \
|
||||
typename TTypes<Tindex>::ConstVec indices, int64 inner_dim); \
|
||||
extern template struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex>;
|
||||
DECLARE_GPU_SPEC(Eigen::half, int32);
|
||||
DECLARE_GPU_SPEC(Eigen::half, int64);
|
||||
DECLARE_GPU_SPEC(float, int32);
|
||||
DECLARE_GPU_SPEC(float, int64);
|
||||
DECLARE_GPU_SPEC(double, int32);
|
||||
DECLARE_GPU_SPEC(double, int64);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
||||
REGISTER_KERNELS(GPU, Eigen::half, int32);
|
||||
REGISTER_KERNELS(GPU, Eigen::half, int64);
|
||||
REGISTER_KERNELS(GPU, float, int32);
|
||||
REGISTER_KERNELS(GPU, float, int64);
|
||||
REGISTER_KERNELS(GPU, double, int32);
|
||||
REGISTER_KERNELS(GPU, double, int64);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
template <typename Device, typename T>
|
||||
|
@ -27,6 +27,168 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
namespace functor {
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "rocm/include/hip/hip_complex.h"
|
||||
|
||||
#endif // TENSORFLOW_USE_ROCM
|
||||
|
||||
// if any kernels involving complex sqrt/rsqrt are compiled with ROCm, build
|
||||
// process completes without errors,but the resulting executable ends up
|
||||
// unusable (throwing errors "no device code available for function" for
|
||||
/// completely unrelated kernels.)
|
||||
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
|
||||
// not provide sqrt for complex.
|
||||
// We have no choice but to implement sqrt and rsqrt by hand
|
||||
template <typename T>
|
||||
__device__ T impl_sqrt(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
template <typename T>
|
||||
__device__ T impl_rsqrt(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
template <>
|
||||
__device__ Eigen::half impl_sqrt(Eigen::half x) {
|
||||
return __float2half(sqrt(__half2float(x)));
|
||||
}
|
||||
template <>
|
||||
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
|
||||
return __float2half(rsqrt(__half2float(x)));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
|
||||
T re = x.real(), im = x.imag();
|
||||
T mod_x = sqrt(re * re + im * im);
|
||||
const T root2 = 0.7071067811865475;
|
||||
// We pick the root with the same sign of the imaginary component as
|
||||
// the input.
|
||||
T root[2] = {T(sqrt(mod_x + re) * root2),
|
||||
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
|
||||
// hcc/clang is really weird with its support of complex in device code;
|
||||
// for some reason it does not permit a 2-argument constructor
|
||||
return *(reinterpret_cast<std::complex<T>*>(&root));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__device__ T rsqrt_helper(T x) {
|
||||
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
|
||||
T re = x.real(), im = x.imag();
|
||||
T r = rsqrt(re * re + im * im);
|
||||
T ar2 = re * r * r;
|
||||
const T root2 = 0.7071067811865475;
|
||||
T root[2];
|
||||
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
|
||||
// due to subtraction of two close values. We have to get fancy
|
||||
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
|
||||
? rsqrt_helper(im * im * r * r)
|
||||
: max(T(0.0), 1 + re * r))) *
|
||||
root2;
|
||||
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
|
||||
? rsqrt_helper(im * im * r * r)
|
||||
: max(T(0.0), 1 - re * r))) *
|
||||
root2 * (im >= 0 ? -1. : 1.);
|
||||
return *(reinterpret_cast<std::complex<T>*>(&root));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T impl_fabs(T x) {
|
||||
return fabs(x);
|
||||
}
|
||||
template <>
|
||||
__device__ Eigen::half impl_fabs(Eigen::half x) {
|
||||
return __float2half(fabs(__half2float(x)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T impl_sign(T x) {
|
||||
return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
|
||||
}
|
||||
|
||||
template <typename T, typename Tindex, bool has_epsilon>
|
||||
__global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
|
||||
T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
|
||||
const Tindex* indices, Tindex param_rows, Tindex updates_size,
|
||||
Tindex indices_size, bool update_slots) {
|
||||
Tindex col_size = updates_size / indices_size;
|
||||
GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
|
||||
Tindex indices_row = grad_index / col_size;
|
||||
Tindex param_row = indices[indices_row];
|
||||
if (param_row < 0 || param_row >= param_rows) {
|
||||
// Ignore indices that are out of range.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute the index of var and accum.
|
||||
Tindex param_index = param_row * col_size + (grad_index % col_size);
|
||||
|
||||
// Read variables.
|
||||
T var_i = var[param_index];
|
||||
T accum_i = accum[param_index];
|
||||
T grad_i = grad[grad_index];
|
||||
const T lr_t = *lr;
|
||||
const T epsilon_t = *epsilon;
|
||||
|
||||
if (update_slots) {
|
||||
accum_i += grad_i * grad_i;
|
||||
}
|
||||
if (has_epsilon) {
|
||||
var_i -= lr_t * grad_i / (sqrt(accum_i) + epsilon_t);
|
||||
} else {
|
||||
var_i -= lr_t * grad_i * impl_rsqrt(accum_i);
|
||||
}
|
||||
|
||||
// Write update back to variables.
|
||||
var[param_index] = var_i;
|
||||
accum[param_index] = accum_i;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Tindex>
|
||||
__global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
|
||||
T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
|
||||
const Tindex* indices, Tindex param_rows, Tindex updates_size,
|
||||
Tindex indices_size) {
|
||||
Tindex col_size = updates_size / indices_size;
|
||||
GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
|
||||
Tindex indices_row = grad_index / col_size;
|
||||
Tindex param_row = indices[indices_row];
|
||||
if (param_row < 0 || param_row >= param_rows) {
|
||||
// Ignore indices that are out of range.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute the index of var and accum.
|
||||
Tindex param_index = param_row * col_size + (grad_index % col_size);
|
||||
|
||||
// Read variables.
|
||||
T var_i = var[param_index];
|
||||
T accum_i = accum[param_index];
|
||||
T grad_i = grad[grad_index];
|
||||
const T lr_t = *lr;
|
||||
const T l1_t = *l1;
|
||||
const T l2_t = *l2;
|
||||
|
||||
accum_i += grad_i * grad_i;
|
||||
T learning_rate = lr_t * impl_rsqrt(accum_i);
|
||||
// compute v = w - lr * grad.
|
||||
T prox_var_i = var_i - grad_i * learning_rate;
|
||||
// compute sign(v) * max(|v| - lr * max(l1, 0), 0)
|
||||
var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
|
||||
max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
|
||||
(T(1.) + l2_t * learning_rate);
|
||||
|
||||
// Write update back to variables.
|
||||
var[param_index] = var_i;
|
||||
accum[param_index] = accum_i;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Tindex, bool has_l2_shrinkage>
|
||||
__global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
|
||||
const T* l1, const T* l2,
|
||||
@ -183,87 +345,6 @@ struct ApplyGradientDescent<GPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "rocm/include/hip/hip_complex.h"
|
||||
|
||||
// if any kernels involving complex sqrt/rsqrt are compiled with ROCm, build
|
||||
// process completes without errors,but the resulting executable ends up
|
||||
// unusable (throwing errors "no device code available for function" for
|
||||
/// completely unrelated kernels.)
|
||||
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
|
||||
// not provide sqrt for complex.
|
||||
// We have no choice but to implement sqrt and rsqrt by hand
|
||||
template <typename T>
|
||||
__device__ T impl_sqrt(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
template <typename T>
|
||||
__device__ T impl_rsqrt(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
template <>
|
||||
__device__ Eigen::half impl_sqrt(Eigen::half x) {
|
||||
return __float2half(sqrt(__half2float(x)));
|
||||
}
|
||||
template <>
|
||||
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
|
||||
return __float2half(rsqrt(__half2float(x)));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
|
||||
T re = x.real(), im = x.imag();
|
||||
T mod_x = sqrt(re * re + im * im);
|
||||
const T root2 = 0.7071067811865475;
|
||||
// We pick the root with the same sign of the imaginary component as
|
||||
// the input.
|
||||
T root[2] = {T(sqrt(mod_x + re) * root2),
|
||||
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
|
||||
// hcc/clang is really weird with its support of complex in device code;
|
||||
// for some reason it does not permit a 2-argument constructor
|
||||
return *(reinterpret_cast<std::complex<T>*>(&root));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__device__ T rsqrt_helper(T x) {
|
||||
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
|
||||
T re = x.real(), im = x.imag();
|
||||
T r = rsqrt(re * re + im * im);
|
||||
T ar2 = re * r * r;
|
||||
const T root2 = 0.7071067811865475;
|
||||
T root[2];
|
||||
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
|
||||
// due to subtraction of two close values. We have to get fancy
|
||||
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
|
||||
? rsqrt_helper(im * im * r * r)
|
||||
: max(T(0.0), 1 + re * r))) *
|
||||
root2;
|
||||
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
|
||||
? rsqrt_helper(im * im * r * r)
|
||||
: max(T(0.0), 1 - re * r))) *
|
||||
root2 * (im >= 0 ? -1. : 1.);
|
||||
return *(reinterpret_cast<std::complex<T>*>(&root));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T impl_fabs(T x) {
|
||||
return fabs(x);
|
||||
}
|
||||
template <>
|
||||
__device__ Eigen::half impl_fabs(Eigen::half x) {
|
||||
return __float2half(fabs(__half2float(x)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T impl_sign(T x) {
|
||||
return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ __launch_bounds__(1024) void ApplyAdagradKernel(GpuLaunchConfig cfg,
|
||||
T* var, T* accum,
|
||||
@ -374,8 +455,6 @@ void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
|
||||
|
||||
using kernel_forward::wrap_kernel_call;
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct ApplyAdagrad<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
|
||||
@ -421,6 +500,27 @@ struct ApplyAdagradV2<GPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Tindex, bool has_epsilon>
|
||||
struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
|
||||
Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
|
||||
typename TTypes<T>::Matrix accum,
|
||||
typename TTypes<T>::ConstScalar lr,
|
||||
typename TTypes<T>::ConstScalar epsilon,
|
||||
typename TTypes<T>::ConstMatrix grad,
|
||||
typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
|
||||
bool update_slots) {
|
||||
const Tindex first_dim_size = var.dimension(0);
|
||||
const Tindex grad_size = grad.size();
|
||||
const Tindex indices_size = indices.size();
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
|
||||
return GpuLaunchKernel(
|
||||
SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
|
||||
lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
|
||||
grad_size, indices_size, update_slots);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyProximalAdagrad<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
|
||||
@ -457,6 +557,28 @@ struct ApplyProximalAdagrad<GPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Tindex>
|
||||
struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
|
||||
Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
|
||||
typename TTypes<T>::Matrix accum,
|
||||
typename TTypes<T>::ConstScalar lr,
|
||||
typename TTypes<T>::ConstScalar l1,
|
||||
typename TTypes<T>::ConstScalar l2,
|
||||
typename TTypes<T>::ConstMatrix grad,
|
||||
typename TTypes<Tindex>::ConstVec indices,
|
||||
int64 inner_dim) {
|
||||
const Tindex first_dim_size = var.dimension(0);
|
||||
const Tindex grad_size = grad.size();
|
||||
const Tindex indices_size = indices.size();
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
|
||||
return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
|
||||
config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), var.data(), accum.data(), lr.data(),
|
||||
l1.data(), l2.data(), grad.data(), indices.data(),
|
||||
first_dim_size, grad_size, indices_size);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ApplyAdadelta<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
|
||||
@ -973,10 +1095,33 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
|
||||
template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
|
||||
#endif
|
||||
|
||||
#define EXPLICITLY_INSTANTIATE_FUNCTOR(T) \
|
||||
template struct functor::SparseApplyAdagrad<GPUDevice, T, int32, \
|
||||
/*has_epsilon=*/false>; \
|
||||
template struct functor::SparseApplyAdagrad<GPUDevice, T, int64, \
|
||||
/*has_epsilon=*/false>; \
|
||||
template struct functor::SparseApplyAdagrad<GPUDevice, T, int32, \
|
||||
/*has_epsilon=*/true>; \
|
||||
template struct functor::SparseApplyAdagrad<GPUDevice, T, int64, \
|
||||
/*has_epsilon=*/true>
|
||||
EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
|
||||
EXPLICITLY_INSTANTIATE_FUNCTOR(float);
|
||||
EXPLICITLY_INSTANTIATE_FUNCTOR(double);
|
||||
#undef EXPLICITLY_INSTANTIATE_FUNCTOR
|
||||
|
||||
template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
|
||||
template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
|
||||
template struct functor::ApplyProximalAdagrad<GPUDevice, double>;
|
||||
|
||||
template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
|
||||
int32>;
|
||||
template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
|
||||
int64>;
|
||||
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
|
||||
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
|
||||
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
|
||||
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;
|
||||
|
||||
template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
|
||||
template struct functor::ApplyAdadelta<GPUDevice, float>;
|
||||
template struct functor::ApplyAdadelta<GPUDevice, double>;
|
||||
|
@ -225,7 +225,7 @@ class TrainingOpsTest(TensorFlowTestCase):
|
||||
|
||||
def _testTypesForSparseAdagrad(self, x, y, lr, grad, indices):
|
||||
self.setUp()
|
||||
with self.session(use_gpu=False):
|
||||
with self.session(use_gpu=True):
|
||||
var = variables.VariableV1(x)
|
||||
accum = variables.VariableV1(y)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
Loading…
Reference in New Issue
Block a user