Address minor review comments from PR 44919

- Added a comment, removed a stray newline, and used an early exit to reduce indentation.
Ben Barsdell 2020-11-23 18:14:03 +11:00
parent 503b948570
commit 7c90b9e60d
2 changed files with 109 additions and 113 deletions
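
The "early exit" mentioned above is the guard-clause pattern: handling the trivial case (an empty indices vector) with an immediate return removes the old if (N > 0) wrapper and one level of indentation from the whole function body. A minimal sketch of the shape of the change (simplified, not the actual TensorFlow code):

// Before: the body is nested inside the non-empty check.
Status Update(Tindex N) {
  if (N > 0) {
    // ... long body ...
  }
  return Status::OK();
}

// After: guard clause first, body at the top level.
Status Update(Tindex N) {
  if (N == 0) return Status::OK();
  // ... long body ...
  return Status::OK();
}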


@@ -172,74 +172,70 @@ struct SparseApplyAdagrad<CPUDevice, T, Tindex, has_epsilon> {
                     typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                     bool update_slots) {
     const Tindex N = static_cast<Tindex>(indices.dimension(0));
+    if (N == 0) return Status::OK();
     const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
     const T lr_scalar = lr();
-    if (N > 0) {
-      const int in_bytes = inner_dim * sizeof(T) * 3;
-      const int out_bytes = inner_dim * sizeof(T) * 2;
-      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
-                                      Eigen::TensorOpCost::MulCost<T>() * 2);
-      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
-      if (inner_dim > 1) {
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-        }
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices(i));
-            auto a = accum.template chip<0>(index);
-            auto g = grad.template chip<0>(i);
-            auto v = var.template chip<0>(index);
-            if (update_slots) {
-              a += g.square();
-            }
-            if (has_epsilon) {
-              v -= g.constant(lr_scalar) * g /
-                   (a.sqrt() + a.constant(epsilon()));
-            } else {
-              v -= g.constant(lr_scalar) * g * a.rsqrt();
-            }
-          }
-        };
-        d.parallelFor(N, cost, shard);
-      } else {
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-        }
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices(i));
-            T& a = accum(index);
-            const T& g = grad(i);
-            if (update_slots) {
-              a += g * g;
-            }
-            if (has_epsilon) {
-              var(index) -=
-                  lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon());
-            } else {
-              var(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
-            }
-          }
-        };
-        d.parallelFor(N, cost, shard);
-      }
-    }
+    const int in_bytes = inner_dim * sizeof(T) * 3;
+    const int out_bytes = inner_dim * sizeof(T) * 2;
+    const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
+                                    Eigen::TensorOpCost::MulCost<T>() * 2);
+    const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
+    if (inner_dim > 1) {
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+      }
+      const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+        for (Tindex i = start_idx; i < end_idx; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices(i));
+          auto a = accum.template chip<0>(index);
+          auto g = grad.template chip<0>(i);
+          auto v = var.template chip<0>(index);
+          if (update_slots) {
+            a += g.square();
+          }
+          if (has_epsilon) {
+            v -= g.constant(lr_scalar) * g / (a.sqrt() + a.constant(epsilon()));
+          } else {
+            v -= g.constant(lr_scalar) * g * a.rsqrt();
+          }
+        }
+      };
+      d.parallelFor(N, cost, shard);
+    } else {
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+      }
+      const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+        for (Tindex i = start_idx; i < end_idx; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices(i));
+          T& a = accum(index);
+          const T& g = grad(i);
+          if (update_slots) {
+            a += g * g;
+          }
+          if (has_epsilon) {
+            var(index) -= lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon());
+          } else {
+            var(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
+          }
+        }
+      };
+      d.parallelFor(N, cost, shard);
+    }
     return Status::OK();
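
Both branches of the hunk above implement the same sparse Adagrad step, once vectorized over rows via Eigen chips and once element by element: accum += grad^2 (only when update_slots is set), then var -= lr * grad / (sqrt(accum) + epsilon), with the epsilon term dropped when has_epsilon is false. A standalone scalar sketch of that update, with hypothetical names not taken from the TensorFlow sources:

#include <cmath>
#include <vector>

// Applies the Adagrad step to one sparse row (illustrative only).
void SparseAdagradRow(std::vector<float>& var_row,
                      std::vector<float>& accum_row,
                      const std::vector<float>& grad_row, float lr,
                      float epsilon, bool update_slots, bool has_epsilon) {
  for (size_t j = 0; j < grad_row.size(); ++j) {
    const float g = grad_row[j];
    if (update_slots) accum_row[j] += g * g;  // accumulate squared gradient
    const float denom = has_epsilon ? std::sqrt(accum_row[j]) + epsilon
                                    : std::sqrt(accum_row[j]);
    var_row[j] -= lr * g / denom;  // scaled gradient-descent step
  }
}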
@@ -285,61 +281,60 @@ struct SparseApplyProximalAdagrad<CPUDevice, T, Tindex> {
                     typename TTypes<Tindex>::ConstVec indices,
                     int64 inner_dim) {
     const Tindex N = static_cast<Tindex>(indices.dimension(0));
+    if (N == 0) return Status::OK();
     const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
     const T lr_scalar = lr();
     const T l1_scalar = l1();
     const T l2_scalar = l2();
-    if (N > 0) {
-      if (inner_dim > 1) {
-        for (Tindex i = 0; i < N; i++) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-          auto a = accum.template chip<0>(index);
-          auto g = grad.template chip<0>(i);
-          auto v = var.template chip<0>(index);
-          a += g.square();
-          // compute learning_rate for current step.
-          auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
-          auto prox_v = v;
-          // v = w - g * learning_rate.
-          prox_v -= g * learning_rate;
-          if (l1_scalar > 0) {
-            // compute sign(v) * max(|v|, 0)
-            v = prox_v.sign() *
-                (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
-                    .cwiseMax(static_cast<T>(0.0)) /
-                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
-          } else {
-            v = prox_v /
-                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
-          }
-        }
-      } else {
-        for (Tindex i = 0; i < N; i++) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-          T& a = accum(index);
-          const T& g = grad(i);
-          a += g * g;
-          auto learning_rate = lr_scalar / std::sqrt(a);
-          auto prox_v = var(index);
-          prox_v -= learning_rate * g;
-          if (l1_scalar > 0) {
-            var(index) = sgn(prox_v) *
-                         std::max(std::abs(prox_v) - learning_rate * l1_scalar,
-                                  static_cast<T>(0.0)) /
-                         (1.0 + l2_scalar * learning_rate);
-          } else {
-            var(index) = prox_v / (1.0 + l2_scalar * learning_rate);
-          }
-        }
-      }
-    }
+    if (inner_dim > 1) {
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+        auto a = accum.template chip<0>(index);
+        auto g = grad.template chip<0>(i);
+        auto v = var.template chip<0>(index);
+        a += g.square();
+        // compute learning_rate for current step.
+        auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
+        auto prox_v = v;
+        // v = w - g * learning_rate.
+        prox_v -= g * learning_rate;
+        if (l1_scalar > 0) {
+          // compute sign(v) * max(|v|, 0)
+          v = prox_v.sign() *
+              (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
+                  .cwiseMax(static_cast<T>(0.0)) /
+              (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
+        } else {
+          v = prox_v /
+              (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
+        }
+      }
+    } else {
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+        T& a = accum(index);
+        const T& g = grad(i);
+        a += g * g;
+        auto learning_rate = lr_scalar / std::sqrt(a);
+        auto prox_v = var(index);
+        prox_v -= learning_rate * g;
+        if (l1_scalar > 0) {
+          var(index) = sgn(prox_v) *
+                       std::max(std::abs(prox_v) - learning_rate * l1_scalar,
+                                static_cast<T>(0.0)) /
+                       (1.0 + l2_scalar * learning_rate);
+        } else {
+          var(index) = prox_v / (1.0 + l2_scalar * learning_rate);
+        }
+      }
+    }
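
The proximal variant above differs from plain Adagrad in what happens after the gradient step: the intermediate value prox_v is soft-thresholded by the l1 term and shrunk by the l2 term. A scalar sketch of that proximal step, mirroring the element-wise branch with hypothetical names:

#include <algorithm>
#include <cmath>

// One proximal-Adagrad update for a single weight (illustrative only).
float ProximalAdagradStep(float v, float& a, float g, float lr, float l1,
                          float l2) {
  a += g * g;                             // accumulate squared gradient
  const float rate = lr / std::sqrt(a);   // per-coordinate learning rate
  const float prox_v = v - rate * g;      // unregularized Adagrad step
  const float shrink = 1.0f + l2 * rate;  // l2 shrinkage factor
  if (l1 > 0) {
    // l1 soft-threshold: sign(prox_v) * max(|prox_v| - rate * l1, 0)
    const float mag = std::max(std::abs(prox_v) - rate * l1, 0.0f);
    return std::copysign(mag, prox_v) / shrink;
  }
  return prox_v / shrink;
}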


@@ -94,6 +94,7 @@ struct ApplyAdagradDA {
 template <typename Device, typename T, typename Tindex, bool has_epsilon>
 struct SparseApplyAdagrad {
+  // Note that epsilon is ignored if has_epsilon is false.
   Status operator()(const Device& d, typename TTypes<T>::Matrix var,
                     typename TTypes<T>::Matrix accum,
                     typename TTypes<T>::ConstScalar lr,
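
Since has_epsilon is a bool template parameter rather than a runtime flag, the if (has_epsilon) branches in the kernels above are resolved when the template is instantiated and the dead branch is compiled out. A minimal illustration of the pattern (hypothetical, not from the TensorFlow sources):

#include <cmath>

// Each instantiation keeps exactly one branch; no runtime check remains.
template <bool has_epsilon>
float Denominator(float accum, float epsilon) {
  if (has_epsilon) {
    return std::sqrt(accum) + epsilon;  // update with epsilon
  }
  return std::sqrt(accum);  // original update without epsilon
}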