Merge pull request #45818 from benbarsdell:gpu-SparseApplyProximalAdagrad-gpu-proper

PiperOrigin-RevId: 351148337
Change-Id: Ifa95dfde1376137c8de1c17b9763fddaaff4de5d
This commit is contained in:
TensorFlower Gardener 2021-01-11 07:28:21 -08:00
commit 9a7c57d1d8
3 changed files with 321 additions and 84 deletions

View File

@ -1942,6 +1942,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tindex) \
template <> \
Status \
SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/false>::operator()( \
const GPUDevice& d, typename TTypes<T>::Matrix var, \
typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
typename TTypes<T>::ConstScalar epsilon, \
typename TTypes<T>::ConstMatrix grad, \
typename TTypes<Tindex>::ConstVec indices, int64 inner_dim, \
bool update_slots); \
extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex, \
/*has_epsilon=*/false>;
DECLARE_GPU_SPEC(Eigen::half, int32);
DECLARE_GPU_SPEC(Eigen::half, int64);
DECLARE_GPU_SPEC(float, int32);
DECLARE_GPU_SPEC(float, int64);
DECLARE_GPU_SPEC(double, int32);
DECLARE_GPU_SPEC(double, int64);
#undef DECLARE_GPU_SPEC
} // namespace functor
REGISTER_KERNELS(GPU, Eigen::half, int32);
REGISTER_KERNELS(GPU, Eigen::half, int64);
REGISTER_KERNELS(GPU, float, int32);
REGISTER_KERNELS(GPU, float, int64);
REGISTER_KERNELS(GPU, double, int32);
REGISTER_KERNELS(GPU, double, int64);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_KERNELS
template <typename Device, typename T, typename Tindex>
@ -2043,6 +2075,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tindex) \
template <> \
Status \
SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/true>::operator()( \
const GPUDevice& d, typename TTypes<T>::Matrix var, \
typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
typename TTypes<T>::ConstScalar epsilon, \
typename TTypes<T>::ConstMatrix grad, \
typename TTypes<Tindex>::ConstVec indices, int64 inner_dim, \
bool update_slots); \
extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex, \
/*has_epsilon=*/true>;
DECLARE_GPU_SPEC(Eigen::half, int32);
DECLARE_GPU_SPEC(Eigen::half, int64);
DECLARE_GPU_SPEC(float, int32);
DECLARE_GPU_SPEC(float, int64);
DECLARE_GPU_SPEC(double, int32);
DECLARE_GPU_SPEC(double, int64);
#undef DECLARE_GPU_SPEC
} // namespace functor
REGISTER_KERNELS(GPU, Eigen::half, int32);
REGISTER_KERNELS(GPU, Eigen::half, int64);
REGISTER_KERNELS(GPU, float, int32);
REGISTER_KERNELS(GPU, float, int64);
REGISTER_KERNELS(GPU, double, int32);
REGISTER_KERNELS(GPU, double, int64);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_KERNELS
template <typename Device, typename T, typename Tindex>
@ -2158,6 +2222,34 @@ REGISTER_KERNELS(CPU, float, int64);
REGISTER_KERNELS(CPU, double, int32);
REGISTER_KERNELS(CPU, double, int64);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tindex) \
template <> \
Status SparseApplyProximalAdagrad<GPUDevice, T, Tindex>::operator()( \
const GPUDevice& d, typename TTypes<T>::Matrix var, \
typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
typename TTypes<T>::ConstMatrix grad, \
typename TTypes<Tindex>::ConstVec indices, int64 inner_dim); \
extern template struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex>;
DECLARE_GPU_SPEC(Eigen::half, int32);
DECLARE_GPU_SPEC(Eigen::half, int64);
DECLARE_GPU_SPEC(float, int32);
DECLARE_GPU_SPEC(float, int64);
DECLARE_GPU_SPEC(double, int32);
DECLARE_GPU_SPEC(double, int64);
#undef DECLARE_GPU_SPEC
} // namespace functor
REGISTER_KERNELS(GPU, Eigen::half, int32);
REGISTER_KERNELS(GPU, Eigen::half, int64);
REGISTER_KERNELS(GPU, float, int32);
REGISTER_KERNELS(GPU, float, int64);
REGISTER_KERNELS(GPU, double, int32);
REGISTER_KERNELS(GPU, double, int64);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_KERNELS
template <typename Device, typename T>

View File

@ -27,6 +27,168 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_complex.h"
#endif // TENSORFLOW_USE_ROCM
// if any kernels involving complex sqrt/rsqrt are compiled with ROCm, build
// process completes without errors,but the resulting executable ends up
// unusable (throwing errors "no device code available for function" for
/// completely unrelated kernels.)
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
// not provide sqrt for complex.
// We have no choice but to implement sqrt and rsqrt by hand
template <typename T>
__device__ T impl_sqrt(T x) {
return sqrt(x);
}
template <typename T>
__device__ T impl_rsqrt(T x) {
return rsqrt(x);
}
template <>
__device__ Eigen::half impl_sqrt(Eigen::half x) {
return __float2half(sqrt(__half2float(x)));
}
template <>
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
return __float2half(rsqrt(__half2float(x)));
}
template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T mod_x = sqrt(re * re + im * im);
const T root2 = 0.7071067811865475;
// We pick the root with the same sign of the imaginary component as
// the input.
T root[2] = {T(sqrt(mod_x + re) * root2),
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
// hcc/clang is really weird with its support of complex in device code;
// for some reason it does not permit a 2-argument constructor
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <class T>
__device__ T rsqrt_helper(T x) {
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}
template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T r = rsqrt(re * re + im * im);
T ar2 = re * r * r;
const T root2 = 0.7071067811865475;
T root[2];
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
// due to subtraction of two close values. We have to get fancy
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
? rsqrt_helper(im * im * r * r)
: max(T(0.0), 1 + re * r))) *
root2;
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
? rsqrt_helper(im * im * r * r)
: max(T(0.0), 1 - re * r))) *
root2 * (im >= 0 ? -1. : 1.);
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <typename T>
__device__ T impl_fabs(T x) {
return fabs(x);
}
template <>
__device__ Eigen::half impl_fabs(Eigen::half x) {
return __float2half(fabs(__half2float(x)));
}
template <typename T>
__device__ T impl_sign(T x) {
return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
}
template <typename T, typename Tindex, bool has_epsilon>
__global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
const Tindex* indices, Tindex param_rows, Tindex updates_size,
Tindex indices_size, bool update_slots) {
Tindex col_size = updates_size / indices_size;
GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
Tindex indices_row = grad_index / col_size;
Tindex param_row = indices[indices_row];
if (param_row < 0 || param_row >= param_rows) {
// Ignore indices that are out of range.
continue;
}
// Compute the index of var and accum.
Tindex param_index = param_row * col_size + (grad_index % col_size);
// Read variables.
T var_i = var[param_index];
T accum_i = accum[param_index];
T grad_i = grad[grad_index];
const T lr_t = *lr;
const T epsilon_t = *epsilon;
if (update_slots) {
accum_i += grad_i * grad_i;
}
if (has_epsilon) {
var_i -= lr_t * grad_i / (sqrt(accum_i) + epsilon_t);
} else {
var_i -= lr_t * grad_i * impl_rsqrt(accum_i);
}
// Write update back to variables.
var[param_index] = var_i;
accum[param_index] = accum_i;
}
}
template <typename T, typename Tindex>
__global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
const Tindex* indices, Tindex param_rows, Tindex updates_size,
Tindex indices_size) {
Tindex col_size = updates_size / indices_size;
GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
Tindex indices_row = grad_index / col_size;
Tindex param_row = indices[indices_row];
if (param_row < 0 || param_row >= param_rows) {
// Ignore indices that are out of range.
continue;
}
// Compute the index of var and accum.
Tindex param_index = param_row * col_size + (grad_index % col_size);
// Read variables.
T var_i = var[param_index];
T accum_i = accum[param_index];
T grad_i = grad[grad_index];
const T lr_t = *lr;
const T l1_t = *l1;
const T l2_t = *l2;
accum_i += grad_i * grad_i;
T learning_rate = lr_t * impl_rsqrt(accum_i);
// compute v = w - lr * grad.
T prox_var_i = var_i - grad_i * learning_rate;
// compute sign(v) * max(|v| - lr * max(l1, 0), 0)
var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
(T(1.) + l2_t * learning_rate);
// Write update back to variables.
var[param_index] = var_i;
accum[param_index] = accum_i;
}
}
template <typename T, typename Tindex, bool has_l2_shrinkage>
__global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
const T* l1, const T* l2,
@ -183,87 +345,6 @@ struct ApplyGradientDescent<GPUDevice, T> {
}
};
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_complex.h"
// if any kernels involving complex sqrt/rsqrt are compiled with ROCm, build
// process completes without errors,but the resulting executable ends up
// unusable (throwing errors "no device code available for function" for
/// completely unrelated kernels.)
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
// not provide sqrt for complex.
// We have no choice but to implement sqrt and rsqrt by hand
template <typename T>
__device__ T impl_sqrt(T x) {
return sqrt(x);
}
template <typename T>
__device__ T impl_rsqrt(T x) {
return rsqrt(x);
}
template <>
__device__ Eigen::half impl_sqrt(Eigen::half x) {
return __float2half(sqrt(__half2float(x)));
}
template <>
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
return __float2half(rsqrt(__half2float(x)));
}
template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T mod_x = sqrt(re * re + im * im);
const T root2 = 0.7071067811865475;
// We pick the root with the same sign of the imaginary component as
// the input.
T root[2] = {T(sqrt(mod_x + re) * root2),
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
// hcc/clang is really weird with its support of complex in device code;
// for some reason it does not permit a 2-argument constructor
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <class T>
__device__ T rsqrt_helper(T x) {
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}
template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T r = rsqrt(re * re + im * im);
T ar2 = re * r * r;
const T root2 = 0.7071067811865475;
T root[2];
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
// due to subtraction of two close values. We have to get fancy
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
? rsqrt_helper(im * im * r * r)
: max(T(0.0), 1 + re * r))) *
root2;
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
? rsqrt_helper(im * im * r * r)
: max(T(0.0), 1 - re * r))) *
root2 * (im >= 0 ? -1. : 1.);
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <typename T>
__device__ T impl_fabs(T x) {
return fabs(x);
}
template <>
__device__ Eigen::half impl_fabs(Eigen::half x) {
return __float2half(fabs(__half2float(x)));
}
template <typename T>
__device__ T impl_sign(T x) {
return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
}
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdagradKernel(GpuLaunchConfig cfg,
T* var, T* accum,
@ -374,8 +455,6 @@ void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
using kernel_forward::wrap_kernel_call;
#endif
template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@ -421,6 +500,27 @@ struct ApplyAdagradV2<GPUDevice, T> {
}
};
template <typename T, typename Tindex, bool has_epsilon>
struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
typename TTypes<T>::Matrix accum,
typename TTypes<T>::ConstScalar lr,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstMatrix grad,
typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
bool update_slots) {
const Tindex first_dim_size = var.dimension(0);
const Tindex grad_size = grad.size();
const Tindex indices_size = indices.size();
GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
return GpuLaunchKernel(
SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
grad_size, indices_size, update_slots);
}
};
template <typename T>
struct ApplyProximalAdagrad<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@ -457,6 +557,28 @@ struct ApplyProximalAdagrad<GPUDevice, T> {
}
};
template <typename T, typename Tindex>
struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
typename TTypes<T>::Matrix accum,
typename TTypes<T>::ConstScalar lr,
typename TTypes<T>::ConstScalar l1,
typename TTypes<T>::ConstScalar l2,
typename TTypes<T>::ConstMatrix grad,
typename TTypes<Tindex>::ConstVec indices,
int64 inner_dim) {
const Tindex first_dim_size = var.dimension(0);
const Tindex grad_size = grad.size();
const Tindex indices_size = indices.size();
GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
config.block_count, config.thread_per_block, 0,
d.stream(), var.data(), accum.data(), lr.data(),
l1.data(), l2.data(), grad.data(), indices.data(),
first_dim_size, grad_size, indices_size);
}
};
template <typename T>
struct ApplyAdadelta<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@ -973,10 +1095,33 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
#endif
#define EXPLICITLY_INSTANTIATE_FUNCTOR(T) \
template struct functor::SparseApplyAdagrad<GPUDevice, T, int32, \
/*has_epsilon=*/false>; \
template struct functor::SparseApplyAdagrad<GPUDevice, T, int64, \
/*has_epsilon=*/false>; \
template struct functor::SparseApplyAdagrad<GPUDevice, T, int32, \
/*has_epsilon=*/true>; \
template struct functor::SparseApplyAdagrad<GPUDevice, T, int64, \
/*has_epsilon=*/true>
EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
EXPLICITLY_INSTANTIATE_FUNCTOR(float);
EXPLICITLY_INSTANTIATE_FUNCTOR(double);
#undef EXPLICITLY_INSTANTIATE_FUNCTOR
template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
template struct functor::ApplyProximalAdagrad<GPUDevice, double>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
int64>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;
template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
template struct functor::ApplyAdadelta<GPUDevice, float>;
template struct functor::ApplyAdadelta<GPUDevice, double>;

View File

@ -225,7 +225,7 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForSparseAdagrad(self, x, y, lr, grad, indices):
self.setUp()
with self.session(use_gpu=False):
with self.session(use_gpu=True):
var = variables.VariableV1(x)
accum = variables.VariableV1(y)
self.evaluate(variables.global_variables_initializer())