Add GPU kernels for SparseApply[Proximal]Adagrad
- Also applies to Resource and V2 versions of the ops.
parent 97024506de
commit 28162ecc0c
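Note (editorial, not part of the commit): the update rules the new kernels implement can be read directly off the kernel bodies below. Per selected row $i = \mathrm{indices}[k]$, with gradient slice $g_k$ and learning rate $\eta$, sparse Adagrad is

$$a_i \leftarrow a_i + g_k^2, \qquad w_i \leftarrow w_i - \frac{\eta\, g_k}{\sqrt{a_i} + \varepsilon},$$

where the accumulator step is skipped when update_slots is false, and the legacy (has_epsilon=false) path drops $\varepsilon$, i.e. uses $\eta\, g_k / \sqrt{a_i}$. The proximal variant, with $\tilde\eta = \eta / \sqrt{a_i}$, applies

$$v = w_i - \tilde\eta\, g_k, \qquad w_i \leftarrow \operatorname{sign}(v)\, \frac{\max(|v| - \tilde\eta \max(l_1, 0),\ 0)}{1 + l_2\, \tilde\eta}.$$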
@@ -1942,6 +1942,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
 TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
 
 #undef REGISTER_CPU_KERNELS
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                            \
+  template <>                                                                  \
+  Status                                                                       \
+  SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/false>::operator()( \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                      \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,    \
+      typename TTypes<T>::ConstScalar epsilon,                                 \
+      typename TTypes<T>::ConstMatrix grad,                                    \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,              \
+      bool update_slots);                                                      \
+  extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex,              \
+                                            /*has_epsilon=*/false>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T, typename Tindex>
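To make DECLARE_GPU_SPEC concrete: expanding it for (float, int32), with whitespace adjusted, yields the declaration pair below. This is only the macro expansion spelled out, not additional code in the commit. The declaration lets this file call the GPU specialization, and the extern template suppresses implicit instantiation here, so the definition can live in the GPU translation unit whose hunks appear later in the diff.

template <>
Status
SparseApplyAdagrad<GPUDevice, float, int32, /*has_epsilon=*/false>::operator()(
    const GPUDevice& d, typename TTypes<float>::Matrix var,
    typename TTypes<float>::Matrix accum,
    typename TTypes<float>::ConstScalar lr,
    typename TTypes<float>::ConstScalar epsilon,
    typename TTypes<float>::ConstMatrix grad,
    typename TTypes<int32>::ConstVec indices, int64 inner_dim,
    bool update_slots);
extern template struct SparseApplyAdagrad<GPUDevice, float, int32,
                                          /*has_epsilon=*/false>;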
@@ -2043,6 +2075,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
 TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
 
 #undef REGISTER_CPU_KERNELS
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                           \
+  template <>                                                                 \
+  Status                                                                      \
+  SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/true>::operator()( \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                     \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,   \
+      typename TTypes<T>::ConstScalar epsilon,                                \
+      typename TTypes<T>::ConstMatrix grad,                                   \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,             \
+      bool update_slots);                                                     \
+  extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex,             \
+                                            /*has_epsilon=*/true>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T, typename Tindex>
@@ -2158,6 +2222,34 @@ REGISTER_KERNELS(CPU, float, int64);
 REGISTER_KERNELS(CPU, double, int32);
 REGISTER_KERNELS(CPU, double, int64);
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                           \
+  template <>                                                                 \
+  Status SparseApplyProximalAdagrad<GPUDevice, T, Tindex>::operator()(        \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                     \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,   \
+      typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
+      typename TTypes<T>::ConstMatrix grad,                                   \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim);            \
+  extern template struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T>
@@ -110,6 +110,85 @@ __device__ T impl_sign(T x) {
   return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
 }
 
+template <typename T, typename Tindex, bool has_epsilon>
+__global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
+    T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
+    const Tindex* indices, Tindex param_rows, Tindex updates_size,
+    Tindex indices_size, bool update_slots) {
+  Tindex col_size = updates_size / indices_size;
+  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
+    Tindex indices_row = grad_index / col_size;
+    Tindex param_row = indices[indices_row];
+    if (param_row < 0 || param_row >= param_rows) {
+      // Ignore indices that are out of range.
+      continue;
+    }
+
+    // Compute the index of var and accum.
+    Tindex param_index = param_row * col_size + (grad_index % col_size);
+
+    // Read variables.
+    T var_i = var[param_index];
+    T accum_i = accum[param_index];
+    T grad_i = grad[grad_index];
+    const T lr_t = *lr;
+    const T epsilon_t = *epsilon;
+
+    if (update_slots) {
+      accum_i += grad_i * grad_i;
+    }
+    if (has_epsilon) {
+      var_i -= lr_t * grad_i / (sqrt(accum_i) + epsilon_t);
+    } else {
+      var_i -= lr_t * grad_i * impl_rsqrt(accum_i);
+    }
+
+    // Write update back to variables.
+    var[param_index] = var_i;
+    accum[param_index] = accum_i;
+  }
+}
+
+template <typename T, typename Tindex>
+__global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
+    T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
+    const Tindex* indices, Tindex param_rows, Tindex updates_size,
+    Tindex indices_size) {
+  Tindex col_size = updates_size / indices_size;
+  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
+    Tindex indices_row = grad_index / col_size;
+    Tindex param_row = indices[indices_row];
+    if (param_row < 0 || param_row >= param_rows) {
+      // Ignore indices that are out of range.
+      continue;
+    }
+
+    // Compute the index of var and accum.
+    Tindex param_index = param_row * col_size + (grad_index % col_size);
+
+    // Read variables.
+    T var_i = var[param_index];
+    T accum_i = accum[param_index];
+    T grad_i = grad[grad_index];
+    const T lr_t = *lr;
+    const T l1_t = *l1;
+    const T l2_t = *l2;
+
+    accum_i += grad_i * grad_i;
+    T learning_rate = lr_t * impl_rsqrt(accum_i);
+    // compute v = w - lr * grad.
+    T prox_var_i = var_i - grad_i * learning_rate;
+    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
+    var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
+            max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
+            (T(1.) + l2_t * learning_rate);
+
+    // Write update back to variables.
+    var[param_index] = var_i;
+    accum[param_index] = accum_i;
+  }
+}
+
 template <typename T, typename Tindex, bool has_l2_shrinkage>
 __global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
                                       const T* l1, const T* l2,
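The two kernels above parallelize a purely element-wise loop over the flattened gradient, so a single-threaded reference is easy to write down for sanity-checking. A minimal sketch (not part of the commit; the names, float/int64_t types, and the row-major [param_rows, col_size] layout are assumptions of this sketch):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference sparse Adagrad: one pass over the flattened grad, mirroring the
// per-thread body of SparseApplyAdagradKernel.
void SparseAdagradRef(std::vector<float>& var, std::vector<float>& accum,
                      const std::vector<float>& grad,
                      const std::vector<int64_t>& indices, int64_t param_rows,
                      float lr, float epsilon, bool has_epsilon,
                      bool update_slots) {
  const int64_t col_size = grad.size() / indices.size();
  for (int64_t g = 0; g < static_cast<int64_t>(grad.size()); ++g) {
    const int64_t row = indices[g / col_size];
    if (row < 0 || row >= param_rows) continue;  // skip out-of-range indices
    const int64_t p = row * col_size + g % col_size;
    if (update_slots) accum[p] += grad[g] * grad[g];
    const float denom =
        has_epsilon ? std::sqrt(accum[p]) + epsilon : std::sqrt(accum[p]);
    var[p] -= lr * grad[g] / denom;
  }
}

// Reference sparse proximal Adagrad, mirroring
// SparseApplyProximalAdagradKernel.
void SparseProximalAdagradRef(std::vector<float>& var,
                              std::vector<float>& accum,
                              const std::vector<float>& grad,
                              const std::vector<int64_t>& indices,
                              int64_t param_rows, float lr, float l1,
                              float l2) {
  const int64_t col_size = grad.size() / indices.size();
  for (int64_t g = 0; g < static_cast<int64_t>(grad.size()); ++g) {
    const int64_t row = indices[g / col_size];
    if (row < 0 || row >= param_rows) continue;
    const int64_t p = row * col_size + g % col_size;
    accum[p] += grad[g] * grad[g];
    const float rate = lr / std::sqrt(accum[p]);
    const float v = var[p] - grad[g] * rate;  // v = w - lr_scaled * grad
    // sign(v) * max(|v| - lr_scaled * max(l1, 0), 0) / (1 + l2 * lr_scaled)
    var[p] = (v >= 0 ? 1.f : -1.f) *
             std::max(std::fabs(v) - rate * std::max(l1, 0.f), 0.f) /
             (1.f + l2 * rate);
  }
}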
@@ -421,6 +500,27 @@ struct ApplyAdagradV2<GPUDevice, T> {
   }
 };
 
+template <typename T, typename Tindex, bool has_epsilon>
+struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
+  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar epsilon,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
+                    bool update_slots) {
+    const Tindex first_dim_size = var.dimension(0);
+    const Tindex grad_size = grad.size();
+    const Tindex indices_size = indices.size();
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
+    return GpuLaunchKernel(
+        SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
+        config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
+        lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
+        grad_size, indices_size, update_slots);
+  }
+};
+
 template <typename T>
 struct ApplyProximalAdagrad<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
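GetGpuLaunchConfig and GpuLaunchKernel are TensorFlow's portability wrappers for choosing a launch configuration and launching on the device stream, and GPU_1D_KERNEL_LOOP is its grid-stride loop macro. A standalone CUDA sketch of the same pattern, with toy work in place of the Adagrad update (plain CUDA runtime API; everything here is illustrative, and the real config heuristic also accounts for occupancy):

#include <cstdio>
#include <cuda_runtime.h>

// Toy grid-stride kernel: the shape of what GPU_1D_KERNEL_LOOP expands to.
__global__ void ScaleKernel(float* data, int n, float scale) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    data[i] *= scale;
  }
}

int main() {
  const int n = 1 << 20;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  cudaMemset(d, 0, n * sizeof(float));
  // Roughly what GetGpuLaunchConfig computes: enough blocks to cover n with
  // a capped thread count (the kernels above use __launch_bounds__(1024)).
  const int threads = 1024;
  const int blocks = (n + threads - 1) / threads;
  ScaleKernel<<<blocks, threads, 0, /*stream=*/0>>>(d, n, 2.0f);
  cudaDeviceSynchronize();
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d);
  return 0;
}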
@@ -457,6 +557,28 @@ struct ApplyProximalAdagrad<GPUDevice, T> {
   }
 };
 
+template <typename T, typename Tindex>
+struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
+  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar l1,
+                    typename TTypes<T>::ConstScalar l2,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices,
+                    int64 inner_dim) {
+    const Tindex first_dim_size = var.dimension(0);
+    const Tindex grad_size = grad.size();
+    const Tindex indices_size = indices.size();
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
+    return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
+                           config.block_count, config.thread_per_block, 0,
+                           d.stream(), var.data(), accum.data(), lr.data(),
+                           l1.data(), l2.data(), grad.data(), indices.data(),
+                           first_dim_size, grad_size, indices_size);
+  }
+};
+
 template <typename T>
 struct ApplyAdadelta<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -973,10 +1095,33 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
 template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
 #endif
 
+#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                             \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
+                                              /*has_epsilon=*/false>; \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
+                                              /*has_epsilon=*/false>; \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
+                                              /*has_epsilon=*/true>;  \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
+                                              /*has_epsilon=*/true>
+EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
+EXPLICITLY_INSTANTIATE_FUNCTOR(float);
+EXPLICITLY_INSTANTIATE_FUNCTOR(double);
+#undef EXPLICITLY_INSTANTIATE_FUNCTOR
+
 template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
 template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
 template struct functor::ApplyProximalAdagrad<GPUDevice, double>;
+
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
+                                                    int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
+                                                    int64>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;
 
 template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdadelta<GPUDevice, float>;
 template struct functor::ApplyAdadelta<GPUDevice, double>;
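These explicit instantiations are the other half of the extern template declarations added to the CPU file earlier in the diff: the .cc side promises the specializations exist, and this GPU translation unit must actually emit them or linking fails. A minimal single-file sketch of the pattern, with hypothetical names:

#include <cstdio>

// Shared header part: declaration only.
template <typename T>
struct MyFunctor {
  T operator()(T x);
};

// "User" side: promise that MyFunctor<float> is instantiated elsewhere,
// so no implicit instantiation happens at the point of use.
extern template struct MyFunctor<float>;

// "Implementation" side: the definition plus an explicit instantiation,
// which is what actually emits MyFunctor<float>::operator().
template <typename T>
T MyFunctor<T>::operator()(T x) {
  return x * x;
}
template struct MyFunctor<float>;

int main() {
  printf("%f\n", MyFunctor<float>()(3.0f));
  return 0;
}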