diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 12d626baf1f..451ce5be118 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -1942,6 +1942,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
 TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
 
 #undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                            \
+  template <>                                                                  \
+  Status                                                                       \
+  SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/false>::operator()( \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                      \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,    \
+      typename TTypes<T>::ConstScalar epsilon,                                 \
+      typename TTypes<T>::ConstMatrix grad,                                    \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,              \
+      bool update_slots);                                                      \
+  extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex,              \
+                                            /*has_epsilon=*/false>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T, typename Tindex>
@@ -2043,6 +2075,38 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
 TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
 
 #undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                           \
+  template <>                                                                 \
+  Status                                                                      \
+  SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/true>::operator()( \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                     \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,   \
+      typename TTypes<T>::ConstScalar epsilon,                                \
+      typename TTypes<T>::ConstMatrix grad,                                   \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,             \
+      bool update_slots);                                                     \
+  extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex,             \
+                                            /*has_epsilon=*/true>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T, typename Tindex>
@@ -2158,6 +2222,34 @@ REGISTER_KERNELS(CPU, float, int64);
 REGISTER_KERNELS(CPU, double, int32);
 REGISTER_KERNELS(CPU, double, int64);
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Tindex)                                           \
+  template <>                                                                 \
+  Status SparseApplyProximalAdagrad<GPUDevice, T, Tindex>::operator()(        \
+      const GPUDevice& d, typename TTypes<T>::Matrix var,                     \
+      typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr,   \
+      typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
+      typename TTypes<T>::ConstMatrix grad,                                   \
+      typename TTypes<Tindex>::ConstVec indices, int64 inner_dim);            \
+  extern template struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex>;
+DECLARE_GPU_SPEC(Eigen::half, int32);
+DECLARE_GPU_SPEC(Eigen::half, int64);
+DECLARE_GPU_SPEC(float, int32);
+DECLARE_GPU_SPEC(float, int64);
+DECLARE_GPU_SPEC(double, int32);
+DECLARE_GPU_SPEC(double, int64);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half, int32);
+REGISTER_KERNELS(GPU, Eigen::half, int64);
+REGISTER_KERNELS(GPU, float, int32);
+REGISTER_KERNELS(GPU, float, int64);
+REGISTER_KERNELS(GPU, double, int32);
+REGISTER_KERNELS(GPU, double, int64);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index d9f14bd4cc5..becd76a124d 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -110,6 +110,85 @@ __device__ T impl_sign(T x) {
   return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
 }
 
+template <typename T, typename Tindex, bool has_epsilon>
+__global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
+    T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
+    const Tindex* indices, Tindex param_rows, Tindex updates_size,
+    Tindex indices_size, bool update_slots) {
+  Tindex col_size = updates_size / indices_size;
+  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
+    Tindex indices_row = grad_index / col_size;
+    Tindex param_row = indices[indices_row];
+    if (param_row < 0 || param_row >= param_rows) {
+      // Ignore indices that are out of range.
+      continue;
+    }
+
+    // Compute the index of var and accum.
+    Tindex param_index = param_row * col_size + (grad_index % col_size);
+
+    // Read variables.
+    T var_i = var[param_index];
+    T accum_i = accum[param_index];
+    T grad_i = grad[grad_index];
+    const T lr_t = *lr;
+    const T epsilon_t = *epsilon;
+
+    if (update_slots) {
+      accum_i += grad_i * grad_i;
+    }
+    if (has_epsilon) {
+      var_i -= lr_t * grad_i / (sqrt(accum_i) + epsilon_t);
+    } else {
+      var_i -= lr_t * grad_i * impl_rsqrt(accum_i);
+    }
+
+    // Write update back to variables.
+    var[param_index] = var_i;
+    accum[param_index] = accum_i;
+  }
+}
+
+template <typename T, typename Tindex>
+__global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
+    T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
+    const Tindex* indices, Tindex param_rows, Tindex updates_size,
+    Tindex indices_size) {
+  Tindex col_size = updates_size / indices_size;
+  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
+    Tindex indices_row = grad_index / col_size;
+    Tindex param_row = indices[indices_row];
+    if (param_row < 0 || param_row >= param_rows) {
+      // Ignore indices that are out of range.
+      continue;
+    }
+
+    // Compute the index of var and accum.
+    Tindex param_index = param_row * col_size + (grad_index % col_size);
+
+    // Read variables.
+    T var_i = var[param_index];
+    T accum_i = accum[param_index];
+    T grad_i = grad[grad_index];
+    const T lr_t = *lr;
+    const T l1_t = *l1;
+    const T l2_t = *l2;
+
+    accum_i += grad_i * grad_i;
+    T learning_rate = lr_t * impl_rsqrt(accum_i);
+    // compute v = w - lr * grad.
+    T prox_var_i = var_i - grad_i * learning_rate;
+    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
+    var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
+            max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
+            (T(1.) + l2_t * learning_rate);
+
+    // Write update back to variables.
+    var[param_index] = var_i;
+    accum[param_index] = accum_i;
+  }
+}
+
 template <typename T, typename Tindex, bool has_l2_shrinkage>
 __global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
                                       const T* l1, const T* l2,
@@ -421,6 +500,27 @@ struct ApplyAdagradV2<GPUDevice, T> {
   }
 };
 
+template <typename T, typename Tindex, bool has_epsilon>
+struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
+  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar epsilon,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
+                    bool update_slots) {
+    const Tindex first_dim_size = var.dimension(0);
+    const Tindex grad_size = grad.size();
+    const Tindex indices_size = indices.size();
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
+    return GpuLaunchKernel(
+        SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
+        config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
+        lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
+        grad_size, indices_size, update_slots);
+  }
+};
+
 template <typename T>
 struct ApplyProximalAdagrad<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -457,6 +557,28 @@ struct ApplyProximalAdagrad<GPUDevice, T> {
   }
 };
 
+template <typename T, typename Tindex>
+struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
+  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
+                    typename TTypes<T>::Matrix accum,
+                    typename TTypes<T>::ConstScalar lr,
+                    typename TTypes<T>::ConstScalar l1,
+                    typename TTypes<T>::ConstScalar l2,
+                    typename TTypes<T>::ConstMatrix grad,
+                    typename TTypes<Tindex>::ConstVec indices,
+                    int64 inner_dim) {
+    const Tindex first_dim_size = var.dimension(0);
+    const Tindex grad_size = grad.size();
+    const Tindex indices_size = indices.size();
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
+    return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
+                           config.block_count, config.thread_per_block, 0,
+                           d.stream(), var.data(), accum.data(), lr.data(),
+                           l1.data(), l2.data(), grad.data(), indices.data(),
+                           first_dim_size, grad_size, indices_size);
+  }
+};
+
 template <typename T>
 struct ApplyAdadelta<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
@@ -973,10 +1095,33 @@ template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
 template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
 #endif
 
+#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                             \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
+                                              /*has_epsilon=*/false>; \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
+                                              /*has_epsilon=*/false>; \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
+                                              /*has_epsilon=*/true>;  \
+  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
+                                              /*has_epsilon=*/true>
+EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
+EXPLICITLY_INSTANTIATE_FUNCTOR(float);
+EXPLICITLY_INSTANTIATE_FUNCTOR(double);
+#undef EXPLICITLY_INSTANTIATE_FUNCTOR
+
 template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
 template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
 template struct functor::ApplyProximalAdagrad<GPUDevice, double>;
 
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
+                                                    int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
+                                                    int64>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
+template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;
+
 template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
 template struct functor::ApplyAdadelta<GPUDevice, float>;
 template struct functor::ApplyAdadelta<GPUDevice, double>;