diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 12d626baf1f..451ce5be118 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -1942,6 +1942,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
 TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                            \
+  template <>                                                                  \
+  Status                                                                       \
+  SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/false>::operator()( \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                      \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,    \
+      typename TTypes<T>::ConstScalar epsilon,                                 \
+      typename TTypes<T>::ConstMatrix grad,                                    \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,              \
+      bool update_slots);                                                      \
+  extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex,              \
+                                            /*has_epsilon=*/false>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T, typename Tindex>
@@ -2043,6 +2075,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
 TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                           \
+  template <>                                                                 \
+  Status                                                                      \
+  SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/true>::operator()( \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                     \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,   \
+      typename TTypes<T>::ConstScalar epsilon,                                \
+      typename TTypes<T>::ConstMatrix grad,                                   \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,             \
+      bool update_slots);                                                     \
+  extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex,             \
+                                            /*has_epsilon=*/true>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T, typename Tindex>
@@ -2158,6 +2222,34 @@ REGISTER_KERNELS(CPU, float, int64);
 REGISTER_KERNELS(CPU, double, int32);
 REGISTER_KERNELS(CPU, double, int64);
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                           \
+  template <>                                                                 \
+  Status SparseApplyProximalAdagrad<GPUDevice, T, Tindex>::operator()(        \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                     \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,   \
+      typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
+      typename TTypes<T>::ConstMatrix grad,                                   \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim);            \
+  extern template struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index d9f14bd4cc5..becd76a124d 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -110,6 +110,85 @@ __device__ T impl_sign(T x) {
   return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
 }
 
+template <typename T, typename Tindex, bool has_epsilon>
+__global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
+    T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
+    const Tindex* indices, Tindex param_rows, Tindex updates_size,
+    Tindex indices_size, bool update_slots) {
+  Tindex col_size = updates_size / indices_size;
+  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
+    Tindex indices_row = grad_index / col_size;
+    Tindex param_row = indices[indices_row];
+    if (param_row < 0 || param_row >= param_rows) {
+      // Ignore indices that are out of range.
+      continue;
+    }
+
+    // Compute the index of var and accum.
+    Tindex param_index = param_row * col_size + (grad_index % col_size);
+
+    // Read variables.
+    T var_i = var[param_index];
+    T accum_i = accum[param_index];
+    T grad_i = grad[grad_index];
+    const T lr_t = *lr;
+    const T epsilon_t = *epsilon;
+
+    if (update_slots) {
+      accum_i += grad_i * grad_i;
+    }
+    if (has_epsilon) {
+      var_i -= lr_t * grad_i / (sqrt(accum_i) + epsilon_t);
+    } else {
+      var_i -= lr_t * grad_i * impl_rsqrt(accum_i);
+    }
+
+    // Write update back to variables.
+    var[param_index] = var_i;
+    accum[param_index] = accum_i;
+  }
+}
+
+template <typename T, typename Tindex>
+__global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
+    T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
+    const Tindex* indices, Tindex param_rows, Tindex updates_size,
+    Tindex indices_size) {
+  Tindex col_size = updates_size / indices_size;
+  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
+    Tindex indices_row = grad_index / col_size;
+    Tindex param_row = indices[indices_row];
+    if (param_row < 0 || param_row >= param_rows) {
+      // Ignore indices that are out of range.
+      continue;
+    }
+
+    // Compute the index of var and accum.
+    Tindex param_index = param_row * col_size + (grad_index % col_size);
+
+    // Read variables.
+    T var_i = var[param_index];
+    T accum_i = accum[param_index];
+    T grad_i = grad[grad_index];
+    const T lr_t = *lr;
+    const T l1_t = *l1;
+    const T l2_t = *l2;
+
+    accum_i += grad_i * grad_i;
+    T learning_rate = lr_t * impl_rsqrt(accum_i);
+    // compute v = w - lr * grad.
+    T prox_var_i = var_i - grad_i * learning_rate;
+    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
+    var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
+            max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
+            (T(1.) + l2_t * learning_rate);
+
+    // Write update back to variables.
+    var[param_index] = var_i;
+    accum[param_index] = accum_i;
+  }
+}
+
 template <typename T, typename Tindex, bool has_l2_shrinkage>
 __global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear,
                                       const T* lr, const T* l1, const T* l2,
@@ -421,6 +500,27 @@ struct ApplyAdagradV2<GPUDevice, T> {
   }
 };
 
+template <typename T, typename Tindex, bool has_epsilon>
+struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
+  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar epsilon,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
+                    bool update_slots) {
+    const Tindex first_dim_size = var.dimension(0);
+    const Tindex grad_size = grad.size();
+    const Tindex indices_size = indices.size();
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
+    return GpuLaunchKernel(
+        SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
+        config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
+        lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
+        grad_size, indices_size, update_slots);
+  }
+};
+
 template <typename T>
 struct ApplyProximalAdagrad<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -457,6 +557,28 @@ struct ApplyProximalAdagrad<GPUDevice, T> {
   }
 };
 
+template <typename T, typename Tindex>
+struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
+  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar l1,
+                    typename TTypes<T>::ConstScalar l2,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices,
+                    int64 inner_dim) {
+    const Tindex first_dim_size = var.dimension(0);
+    const Tindex grad_size = grad.size();
+    const Tindex indices_size = indices.size();
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
+    return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
+                           config.block_count, config.thread_per_block, 0,
+                           d.stream(), var.data(), accum.data(), lr.data(),
+                           l1.data(), l2.data(), grad.data(), indices.data(),
+                           first_dim_size, grad_size, indices_size);
+  }
+};
+
 template <typename T>
 struct ApplyAdadelta<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -973,10 +1095,33 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
 template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
 #endif
 
+#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                              \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
+                                              /*has_epsilon=*/false>; \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
+                                              /*has_epsilon=*/false>; \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
+                                              /*has_epsilon=*/true>;  \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
+                                              /*has_epsilon=*/true>
+EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
+EXPLICITLY_INSTANTIATE_FUNCTOR(float);
+EXPLICITLY_INSTANTIATE_FUNCTOR(double);
+#undef EXPLICITLY_INSTANTIATE_FUNCTOR
+
 template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
 template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
 template struct functor::ApplyProximalAdagrad<GPUDevice, double>;
 
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
+                                                    int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
+                                                    int64>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;
+
 template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdadelta<GPUDevice, float>;
 template struct functor::ApplyAdadelta<GPUDevice, double>;
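
For reference, the per-element math the new kernels apply to each gathered row can be checked on the host with a scalar sketch. This is not part of the patch; the function names AdagradStep and ProximalAdagradStep are illustrative only, and the snippet simply mirrors the update formulas in SparseApplyAdagradKernel and SparseApplyProximalAdagradKernel above (has_epsilon selects the AdagradV2 variant; out-of-range index handling is omitted because it only applies to the gather step).

// Hypothetical standalone reference, assuming float; compile with any C++11 compiler.
#include <algorithm>
#include <cmath>
#include <cstdio>

// One scalar Adagrad step, matching SparseApplyAdagradKernel for one element.
float AdagradStep(float* var, float* accum, float lr, float epsilon, float g,
                  bool has_epsilon, bool update_slots) {
  if (update_slots) *accum += g * g;
  if (has_epsilon) {
    *var -= lr * g / (std::sqrt(*accum) + epsilon);  // SparseApplyAdagradV2
  } else {
    *var -= lr * g / std::sqrt(*accum);              // SparseApplyAdagrad
  }
  return *var;
}

// One scalar proximal Adagrad step, matching SparseApplyProximalAdagradKernel.
float ProximalAdagradStep(float* var, float* accum, float lr, float l1,
                          float l2, float g) {
  *accum += g * g;
  const float rate = lr / std::sqrt(*accum);
  const float prox = *var - g * rate;  // v = w - lr * grad
  // sign(v) * max(|v| - lr * max(l1, 0), 0) / (1 + l2 * lr)
  *var = (prox >= 0.f ? 1.f : -1.f) *
         std::max(std::fabs(prox) - rate * std::max(l1, 0.f), 0.f) /
         (1.f + l2 * rate);
  return *var;
}

int main() {
  float var = 1.0f, accum = 0.1f;
  AdagradStep(&var, &accum, /*lr=*/0.01f, /*epsilon=*/1e-7f, /*g=*/0.5f,
              /*has_epsilon=*/true, /*update_slots=*/true);
  std::printf("adagrad: var=%f accum=%f\n", var, accum);

  var = 1.0f;
  accum = 0.1f;
  ProximalAdagradStep(&var, &accum, 0.01f, /*l1=*/0.001f, /*l2=*/0.01f, 0.5f);
  std::printf("proximal: var=%f accum=%f\n", var, accum);
  return 0;
}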