diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index b6c37707acc..7451004911e 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -1426,7 +1426,16 @@ class SparseApplyAdagradOp : public OpKernel { errors::InvalidArgument( "Inner dimension should be greater than zero.")); + // This op is implemented only for CPU device. + const auto& d = ctx->eigen_cpu_device(); + if (N > 0) { + const int in_bytes = inner_dim * sizeof(T) * 3; + const int out_bytes = inner_dim * sizeof(T) * 2; + const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost() * 2 + + Eigen::TensorOpCost::MulCost() * 2); + const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles); + if (inner_dim > 1) { const Tindex first_dim_size = var.dim_size(0); auto indices_vec = indices.vec(); @@ -1435,22 +1444,29 @@ class SparseApplyAdagradOp : public OpKernel { auto grad_flat = grad.flat_outer_dims(); T lr_scalar = lr.scalar()(); - // Note(yonghui): It might be worth multi-threading square() and - // rsqrt(). - for (Tindex i = 0; i < N; i++) { + for (Tindex i = 0; i < N; ++i) { const Tindex index = internal::SubtleMustCopy(indices_vec(i)); OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), errors::InvalidArgument( strings::StrCat("Index ", index, " at offset ", i, " in indices is out of range"))); - auto a = accum_flat.template chip<0>(index); - auto g = grad_flat.template chip<0>(i); - auto v = var_flat.template chip<0>(index); - if (update_slots_) { - a += g.square(); - } - v -= g.constant(lr_scalar) * g * a.rsqrt(); } + + const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void { + for (Tindex i = start_idx; i < end_idx; ++i) { + const Tindex index = internal::SubtleMustCopy(indices_vec(i)); + auto a = accum_flat.template chip<0>(index); + auto g = grad_flat.template chip<0>(i); + auto v = var_flat.template chip<0>(index); + if (update_slots_) { + a += g.square(); + } + v -= g.constant(lr_scalar) * g * a.rsqrt(); + } + }; + + d.parallelFor(N, cost, shard); + } else { auto indices_vec = indices.vec(); auto var_flat = var.flat(); @@ -1459,19 +1475,27 @@ class SparseApplyAdagradOp : public OpKernel { T lr_scalar = lr.scalar()(); const Tindex first_dim_size = accum_flat.size(); - for (Tindex i = 0; i < N; i++) { + for (Tindex i = 0; i < N; ++i) { const Tindex index = internal::SubtleMustCopy(indices_vec(i)); OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), errors::InvalidArgument( strings::StrCat("Index ", index, " at offset ", i, " in indices is out of range"))); - T& a = accum_flat(index); - const T& g = grad_flat(i); - if (update_slots_) { - a += g * g; - } - var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a); } + + const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void { + for (Tindex i = start_idx; i < end_idx; ++i) { + const Tindex index = internal::SubtleMustCopy(indices_vec(i)); + T& a = accum_flat(index); + const T& g = grad_flat(i); + if (update_slots_) { + a += g * g; + } + var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a); + } + }; + + d.parallelFor(N, cost, shard); } }