diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index b0c8a8312ba..d7897e942c6 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -172,74 +172,70 @@ struct SparseApplyAdagrad {
                     typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                     bool update_slots) {
     const Tindex N = static_cast<Tindex>(indices.dimension(0));
+    if (N == 0) return Status::OK();
     const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
     const T lr_scalar = lr();
-    if (N > 0) {
-      const int in_bytes = inner_dim * sizeof(T) * 3;
-      const int out_bytes = inner_dim * sizeof(T) * 2;
-      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
-                                      Eigen::TensorOpCost::MulCost<T>() * 2);
-      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
+    const int in_bytes = inner_dim * sizeof(T) * 3;
+    const int out_bytes = inner_dim * sizeof(T) * 2;
+    const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
+                                    Eigen::TensorOpCost::MulCost<T>() * 2);
+    const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
 
-      if (inner_dim > 1) {
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
+    if (inner_dim > 1) {
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
         }
-
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices(i));
-            auto a = accum.template chip<0>(index);
-            auto g = grad.template chip<0>(i);
-            auto v = var.template chip<0>(index);
-            if (update_slots) {
-              a += g.square();
-            }
-            if (has_epsilon) {
-              v -= g.constant(lr_scalar) * g /
-                   (a.sqrt() + a.constant(epsilon()));
-            } else {
-              v -= g.constant(lr_scalar) * g * a.rsqrt();
-            }
-          }
-        };
-
-        d.parallelFor(N, cost, shard);
-
-      } else {
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-        }
-
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices(i));
-            T& a = accum(index);
-            const T& g = grad(i);
-            if (update_slots) {
-              a += g * g;
-            }
-            if (has_epsilon) {
-              var(index) -=
-                  lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon());
-            } else {
-              var(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
-            }
-          }
-        };
-
-        d.parallelFor(N, cost, shard);
       }
+
+      const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+        for (Tindex i = start_idx; i < end_idx; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices(i));
+          auto a = accum.template chip<0>(index);
+          auto g = grad.template chip<0>(i);
+          auto v = var.template chip<0>(index);
+          if (update_slots) {
+            a += g.square();
+          }
+          if (has_epsilon) {
+            v -= g.constant(lr_scalar) * g / (a.sqrt() + a.constant(epsilon()));
+          } else {
+            v -= g.constant(lr_scalar) * g * a.rsqrt();
+          }
+        }
+      };
+
+      d.parallelFor(N, cost, shard);
+    } else {
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+      }
+
+      const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+        for (Tindex i = start_idx; i < end_idx; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices(i));
+          T& a = accum(index);
+          const T& g = grad(i);
+          if (update_slots) {
+            a += g * g;
+          }
+          if (has_epsilon) {
+            var(index) -= lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon());
+          } else {
+            var(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
+          }
+        }
+      };
+
+      d.parallelFor(N, cost, shard);
     }
 
     return Status::OK();
@@ -285,61 +281,60 @@ struct SparseApplyProximalAdagrad {
                     typename TTypes<Tindex>::ConstVec indices,
                     int64 inner_dim) {
     const Tindex N = static_cast<Tindex>(indices.dimension(0));
+    if (N == 0) return Status::OK();
     const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
     const T lr_scalar = lr();
     const T l1_scalar = l1();
     const T l2_scalar = l2();
-    if (N > 0) {
-      if (inner_dim > 1) {
-        for (Tindex i = 0; i < N; i++) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-          auto a = accum.template chip<0>(index);
-          auto g = grad.template chip<0>(i);
-          auto v = var.template chip<0>(index);
-          a += g.square();
-          // compute learning_rate for current step.
-          auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
-          auto prox_v = v;
-          // v = w - g * learning_rate.
-          prox_v -= g * learning_rate;
-          if (l1_scalar > 0) {
-            // compute sign(v) * max(|v|, 0)
-            v = prox_v.sign() *
-                (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
-                    .cwiseMax(static_cast<T>(0.0)) /
-                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
-          } else {
-            v = prox_v /
-                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
-          }
+    if (inner_dim > 1) {
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
         }
-      } else {
-        for (Tindex i = 0; i < N; i++) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-          T& a = accum(index);
-          const T& g = grad(i);
-          a += g * g;
-          auto learning_rate = lr_scalar / std::sqrt(a);
-          auto prox_v = var(index);
-          prox_v -= learning_rate * g;
-          if (l1_scalar > 0) {
-            var(index) = sgn(prox_v) *
-                         std::max(std::abs(prox_v) - learning_rate * l1_scalar,
-                                  static_cast<T>(0.0)) /
-                         (1.0 + l2_scalar * learning_rate);
-          } else {
-            var(index) = prox_v / (1.0 + l2_scalar * learning_rate);
-          }
+        auto a = accum.template chip<0>(index);
+        auto g = grad.template chip<0>(i);
+        auto v = var.template chip<0>(index);
+        a += g.square();
+        // compute learning_rate for current step.
+        auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
+        auto prox_v = v;
+        // v = w - g * learning_rate.
+        prox_v -= g * learning_rate;
+        if (l1_scalar > 0) {
+          // compute sign(v) * max(|v|, 0)
+          v = prox_v.sign() *
+              (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
+                  .cwiseMax(static_cast<T>(0.0)) /
+              (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
+        } else {
+          v = prox_v /
+              (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
+        }
+      }
+    } else {
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+        T& a = accum(index);
+        const T& g = grad(i);
+        a += g * g;
+        auto learning_rate = lr_scalar / std::sqrt(a);
+        auto prox_v = var(index);
+        prox_v -= learning_rate * g;
+        if (l1_scalar > 0) {
+          var(index) = sgn(prox_v) *
+                       std::max(std::abs(prox_v) - learning_rate * l1_scalar,
+                                static_cast<T>(0.0)) /
+                       (1.0 + l2_scalar * learning_rate);
+        } else {
+          var(index) = prox_v / (1.0 + l2_scalar * learning_rate);
         }
       }
     }
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index 12abeba22d1..886d6760a08 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -94,6 +94,7 @@ struct ApplyAdagradDA {
 
 template <typename Device, typename T, typename Tindex, bool has_epsilon>
 struct SparseApplyAdagrad {
+  // Note that epsilon is ignored if has_epsilon is false.
   Status operator()(const Device& d, typename TTypes<T>::Matrix var,
                     typename TTypes<T>::Matrix accum,
                     typename TTypes<T>::ConstScalar lr,
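Note for reviewers: the update applied by the scalar (inner_dim == 1) branch of SparseApplyAdagrad above is the standard sparse Adagrad step, unchanged by this refactor. The sketch below is a minimal standalone illustration of that per-row math, not the TensorFlow functor itself: the function name sparse_apply_adagrad_sketch and the plain std::vector containers are placeholders, bounds checking and the d.parallelFor sharding are omitted, and has_epsilon mirrors the template flag documented in the new header comment (epsilon is ignored when it is false).

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of the per-row update in the inner_dim == 1 branch:
//   accum[idx] += g * g                                  (if update_slots)
//   var[idx]   -= lr * g / (sqrt(accum[idx]) + epsilon)  (if has_epsilon)
//   var[idx]   -= lr * g / sqrt(accum[idx])              (otherwise)
void sparse_apply_adagrad_sketch(std::vector<float>& var,
                                 std::vector<float>& accum,
                                 const std::vector<float>& grad,
                                 const std::vector<std::size_t>& indices,
                                 float lr, float epsilon, bool update_slots,
                                 bool has_epsilon) {
  // Mirrors the early return this change adds when there is nothing to do.
  if (indices.empty()) return;
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const std::size_t index = indices[i];  // assumed already bounds-checked
    float& a = accum[index];
    const float g = grad[i];
    if (update_slots) {
      a += g * g;  // accumulate squared gradients
    }
    if (has_epsilon) {
      var[index] -= lr * g / (std::sqrt(a) + epsilon);
    } else {
      var[index] -= lr * g / std::sqrt(a);  // epsilon ignored when has_epsilon is false
    }
  }
}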