Address minor review comments from PR 44919
- Added a comment, removed a newline, and used an early exit to reduce indentation.
parent 503b948570
commit 7c90b9e60d
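For context on the third item: the "early exit" is the guard-clause pattern, returning as soon as there is no work instead of nesting the whole body under a condition. A minimal standalone sketch of the before/after shape (hypothetical function names, with a stub Status standing in for tensorflow::Status):

#include <cstdint>

struct Status {  // stub for tensorflow::Status, for illustration only
  static Status OK() { return Status(); }
};

// Before: the entire body sits one level deeper under the condition.
Status ApplyNested(int64_t n) {
  if (n > 0) {
    // ... full update body, indented an extra level ...
  }
  return Status::OK();
}

// After: exit early on the empty case; the body stays at top level.
Status ApplyEarlyExit(int64_t n) {
  if (n == 0) return Status::OK();
  // ... full update body, no extra indentation ...
  return Status::OK();
}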
@@ -172,74 +172,70 @@ struct SparseApplyAdagrad<CPUDevice, T, Tindex, has_epsilon> {
                     typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                     bool update_slots) {
     const Tindex N = static_cast<Tindex>(indices.dimension(0));
+    if (N == 0) return Status::OK();
     const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
     const T lr_scalar = lr();
-    if (N > 0) {
-      const int in_bytes = inner_dim * sizeof(T) * 3;
-      const int out_bytes = inner_dim * sizeof(T) * 2;
-      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
-                                      Eigen::TensorOpCost::MulCost<T>() * 2);
-      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
+    const int in_bytes = inner_dim * sizeof(T) * 3;
+    const int out_bytes = inner_dim * sizeof(T) * 2;
+    const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
+                                    Eigen::TensorOpCost::MulCost<T>() * 2);
+    const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
 
-      if (inner_dim > 1) {
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-        }
+    if (inner_dim > 1) {
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+      }
 
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices(i));
-            auto a = accum.template chip<0>(index);
-            auto g = grad.template chip<0>(i);
-            auto v = var.template chip<0>(index);
-            if (update_slots) {
-              a += g.square();
-            }
-            if (has_epsilon) {
-              v -= g.constant(lr_scalar) * g /
-                   (a.sqrt() + a.constant(epsilon()));
-            } else {
-              v -= g.constant(lr_scalar) * g * a.rsqrt();
-            }
-          }
-        };
+      const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+        for (Tindex i = start_idx; i < end_idx; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices(i));
+          auto a = accum.template chip<0>(index);
+          auto g = grad.template chip<0>(i);
+          auto v = var.template chip<0>(index);
+          if (update_slots) {
+            a += g.square();
+          }
+          if (has_epsilon) {
+            v -= g.constant(lr_scalar) * g / (a.sqrt() + a.constant(epsilon()));
+          } else {
+            v -= g.constant(lr_scalar) * g * a.rsqrt();
+          }
+        }
+      };
 
-        d.parallelFor(N, cost, shard);
-      } else {
-        for (Tindex i = 0; i < N; ++i) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-        }
+      d.parallelFor(N, cost, shard);
+    } else {
+      for (Tindex i = 0; i < N; ++i) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+      }
 
-        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
-          for (Tindex i = start_idx; i < end_idx; ++i) {
-            const Tindex index = internal::SubtleMustCopy(indices(i));
-            T& a = accum(index);
-            const T& g = grad(i);
-            if (update_slots) {
-              a += g * g;
-            }
-            if (has_epsilon) {
-              var(index) -=
-                  lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon());
-            } else {
-              var(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
-            }
-          }
-        };
+      const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+        for (Tindex i = start_idx; i < end_idx; ++i) {
+          const Tindex index = internal::SubtleMustCopy(indices(i));
+          T& a = accum(index);
+          const T& g = grad(i);
+          if (update_slots) {
+            a += g * g;
+          }
+          if (has_epsilon) {
+            var(index) -= lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon());
+          } else {
+            var(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
+          }
+        }
+      };
 
-        d.parallelFor(N, cost, shard);
-      }
-    }
+      d.parallelFor(N, cost, shard);
+    }
 
     return Status::OK();
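To make the update in this hunk easier to follow: per selected row, Adagrad accumulates squared gradients and scales the step by the inverse square root of the accumulator. A hedged scalar sketch of one row (hypothetical helper, not part of the kernel):

#include <cmath>
#include <cstddef>
#include <vector>

// Per-row sparse Adagrad step, mirroring the shard bodies above:
//   accum += grad^2                               (if update_slots)
//   var   -= lr * grad / (sqrt(accum) + epsilon)  (has_epsilon == true)
//   var   -= lr * grad / sqrt(accum)              (has_epsilon == false)
void SparseAdagradRow(std::vector<float>& var_row,
                      std::vector<float>& accum_row,
                      const std::vector<float>& grad_row, float lr,
                      float epsilon, bool update_slots, bool has_epsilon) {
  for (std::size_t j = 0; j < var_row.size(); ++j) {
    const float g = grad_row[j];
    if (update_slots) accum_row[j] += g * g;
    const float denom = has_epsilon ? std::sqrt(accum_row[j]) + epsilon
                                    : std::sqrt(accum_row[j]);
    var_row[j] -= lr * g / denom;
  }
}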
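The Eigen::TensorOpCost passed to d.parallelFor above is what lets Eigen choose a shard size: in_bytes counts the three row reads (var, accum, grad), out_bytes the two row writes (var, accum), and cycles the adds and multiplies per row. A hedged standalone sketch of the same pattern (assumes Eigen's unsupported Tensor module and its ThreadPoolDevice::parallelFor(n, cost, fn) overload):

#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::ThreadPool pool(4);
  Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);

  const int inner_dim = 128;
  // Same shape of estimate as the kernel: bytes moved and cycles per unit.
  const Eigen::TensorOpCost cost(
      /*bytes_loaded=*/inner_dim * sizeof(float) * 3,
      /*bytes_stored=*/inner_dim * sizeof(float) * 2,
      /*compute_cycles=*/inner_dim * 4.0);

  // Eigen splits [0, N) into shards sized from the cost estimate.
  device.parallelFor(1000, cost, [](Eigen::Index start, Eigen::Index end) {
    for (Eigen::Index i = start; i < end; ++i) {
      // ... per-row work ...
    }
  });
  return 0;
}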
@@ -285,61 +281,60 @@ struct SparseApplyProximalAdagrad<CPUDevice, T, Tindex> {
                     typename TTypes<Tindex>::ConstVec indices,
                     int64 inner_dim) {
     const Tindex N = static_cast<Tindex>(indices.dimension(0));
+    if (N == 0) return Status::OK();
     const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
     const T lr_scalar = lr();
     const T l1_scalar = l1();
     const T l2_scalar = l2();
-    if (N > 0) {
-      if (inner_dim > 1) {
-        for (Tindex i = 0; i < N; i++) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-          auto a = accum.template chip<0>(index);
-          auto g = grad.template chip<0>(i);
-          auto v = var.template chip<0>(index);
-          a += g.square();
-          // compute learning_rate for current step.
-          auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
-          auto prox_v = v;
-          // v = w - g * learning_rate.
-          prox_v -= g * learning_rate;
-          if (l1_scalar > 0) {
-            // compute sign(v) * max(|v|, 0)
-            v = prox_v.sign() *
-                (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
-                    .cwiseMax(static_cast<T>(0.0)) /
-                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
-          } else {
-            v = prox_v /
-                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
-          }
-        }
-      } else {
-        for (Tindex i = 0; i < N; i++) {
-          const Tindex index = internal::SubtleMustCopy(indices(i));
-          if (!FastBoundsCheck(index, first_dim_size)) {
-            return errors::InvalidArgument(
-                strings::StrCat("Index ", index, " at offset ", i,
-                                " in indices is out of range"));
-          }
-          T& a = accum(index);
-          const T& g = grad(i);
-          a += g * g;
-          auto learning_rate = lr_scalar / std::sqrt(a);
-          auto prox_v = var(index);
-          prox_v -= learning_rate * g;
-          if (l1_scalar > 0) {
-            var(index) = sgn(prox_v) *
-                         std::max(std::abs(prox_v) - learning_rate * l1_scalar,
-                                  static_cast<T>(0.0)) /
-                         (1.0 + l2_scalar * learning_rate);
-          } else {
-            var(index) = prox_v / (1.0 + l2_scalar * learning_rate);
-          }
-        }
-      }
-    }
+    if (inner_dim > 1) {
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+        auto a = accum.template chip<0>(index);
+        auto g = grad.template chip<0>(i);
+        auto v = var.template chip<0>(index);
+        a += g.square();
+        // compute learning_rate for current step.
+        auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
+        auto prox_v = v;
+        // v = w - g * learning_rate.
+        prox_v -= g * learning_rate;
+        if (l1_scalar > 0) {
+          // compute sign(v) * max(|v|, 0)
+          v = prox_v.sign() *
+              (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
+                  .cwiseMax(static_cast<T>(0.0)) /
+              (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
+        } else {
+          v = prox_v /
+              (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
+        }
+      }
+    } else {
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = internal::SubtleMustCopy(indices(i));
+        if (!FastBoundsCheck(index, first_dim_size)) {
+          return errors::InvalidArgument(
+              strings::StrCat("Index ", index, " at offset ", i,
+                              " in indices is out of range"));
+        }
+        T& a = accum(index);
+        const T& g = grad(i);
+        a += g * g;
+        auto learning_rate = lr_scalar / std::sqrt(a);
+        auto prox_v = var(index);
+        prox_v -= learning_rate * g;
+        if (l1_scalar > 0) {
+          var(index) = sgn(prox_v) *
+                       std::max(std::abs(prox_v) - learning_rate * l1_scalar,
+                                static_cast<T>(0.0)) /
+                       (1.0 + l2_scalar * learning_rate);
+        } else {
+          var(index) = prox_v / (1.0 + l2_scalar * learning_rate);
+        }
+      }
+    }
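For readers of this hunk, the scalar else-branch implements the proximal step: take a plain Adagrad step, then apply L1 soft-thresholding and L2 shrinkage. A hedged one-element sketch (hypothetical helper mirroring the code above):

#include <algorithm>
#include <cmath>

// One proximal-Adagrad update, following the scalar branch above:
//   lr_t   = lr / sqrt(accum)                     (per-step learning rate)
//   prox_v = v - lr_t * g                         (gradient step)
//   v      = sgn(prox_v) * max(|prox_v| - lr_t * l1, 0) / (1 + l2 * lr_t)
float ProximalAdagradStep(float v, float& accum, float g, float lr, float l1,
                          float l2) {
  accum += g * g;
  const float lr_t = lr / std::sqrt(accum);
  const float prox_v = v - lr_t * g;
  if (l1 > 0) {
    // Soft-threshold: shrink toward zero by lr_t * l1, clamp at zero.
    const float shrunk = std::max(std::abs(prox_v) - lr_t * l1, 0.0f);
    return std::copysign(shrunk, prox_v) / (1.0f + l2 * lr_t);
  }
  return prox_v / (1.0f + l2 * lr_t);
}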
@@ -94,6 +94,7 @@ struct ApplyAdagradDA {
 
 template <typename Device, typename T, typename Tindex, bool has_epsilon>
 struct SparseApplyAdagrad {
+  // Note that epsilon is ignored if has_epsilon is false.
   Status operator()(const Device& d, typename TTypes<T>::Matrix var,
                     typename TTypes<T>::Matrix accum,
                     typename TTypes<T>::ConstScalar lr,
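The comment added in this hunk documents the has_epsilon template flag. Because it is a compile-time bool, the unused branch is resolved at instantiation and the epsilon argument can be safely ignored when the flag is false. An illustrative sketch of the same dispatch style (not from the diff):

#include <cmath>

// Compile-time dispatch on a bool template parameter, like has_epsilon above.
template <bool has_epsilon>
float AdagradDenom(float accum, float epsilon) {
  if (has_epsilon) {           // constant at instantiation time
    return std::sqrt(accum) + epsilon;
  }
  return std::sqrt(accum);     // epsilon is ignored, as the new comment notes
}

// Usage: AdagradDenom<true>(a, eps) vs. AdagradDenom<false>(a, /*ignored*/ 0.f)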