Parallelize SparseApplyAdagradOp
PiperOrigin-RevId: 254106236
parent 223d17868c
commit c174697c09
@@ -1426,7 +1426,16 @@ class SparseApplyAdagradOp : public OpKernel {
                 errors::InvalidArgument(
                     "Inner dimension should be greater than zero."));
 
+    // This op is implemented only for CPU device.
+    const auto& d = ctx->eigen_cpu_device();
+
     if (N > 0) {
+      const int in_bytes = inner_dim * sizeof(T) * 3;
+      const int out_bytes = inner_dim * sizeof(T) * 2;
+      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
+                                      Eigen::TensorOpCost::MulCost<T>() * 2);
+      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
+
       if (inner_dim > 1) {
         const Tindex first_dim_size = var.dim_size(0);
         auto indices_vec = indices.vec<Tindex>();
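The in_bytes / out_bytes / cycles figures added above feed Eigen's cost model. Below is a minimal standalone sketch (not TensorFlow code; the names, sizes, and the dummy update are illustrative only) of the same pattern: an Eigen::TensorOpCost describing the work for one item is passed to ThreadPoolDevice::parallelFor, which splits the N items into [start, end) shards sized from that cost and runs them on a thread pool.

// Standalone sketch of Eigen's cost-driven parallelFor (illustrative only).
#define EIGEN_USE_THREADS
#include <vector>
#include <unsupported/Eigen/CXX11/ThreadPool>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  using T = float;
  const int num_threads = 4;
  Eigen::ThreadPool pool(num_threads);
  Eigen::ThreadPoolDevice device(&pool, num_threads);

  const Eigen::Index N = 1 << 20;    // number of rows to update
  const Eigen::Index inner_dim = 8;  // elements per row
  std::vector<T> data(N * inner_dim, 1.0f);

  // Per-row cost estimate, mirroring the one added to the op above.
  const double in_bytes = inner_dim * sizeof(T) * 3;
  const double out_bytes = inner_dim * sizeof(T) * 2;
  const double cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
                                     Eigen::TensorOpCost::MulCost<T>() * 2);
  const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);

  // Each shard receives a half-open row range [start, end).
  device.parallelFor(N, cost, [&](Eigen::Index start, Eigen::Index end) {
    for (Eigen::Index i = start; i < end; ++i) {
      for (Eigen::Index j = 0; j < inner_dim; ++j) {
        data[i * inner_dim + j] *= 0.5f;  // placeholder for the real update
      }
    }
  });
  return 0;
}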
@@ -1435,22 +1444,29 @@ class SparseApplyAdagradOp : public OpKernel {
         auto grad_flat = grad.flat_outer_dims<T>();
         T lr_scalar = lr.scalar<T>()();
 
-        // Note(yonghui): It might be worth multi-threading square() and
-        // rsqrt().
-        for (Tindex i = 0; i < N; i++) {
+        for (Tindex i = 0; i < N; ++i) {
           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                       errors::InvalidArgument(
                           strings::StrCat("Index ", index, " at offset ", i,
                                           " in indices is out of range")));
-          auto a = accum_flat.template chip<0>(index);
-          auto g = grad_flat.template chip<0>(i);
-          auto v = var_flat.template chip<0>(index);
-          if (update_slots_) {
-            a += g.square();
-          }
-          v -= g.constant(lr_scalar) * g * a.rsqrt();
         }
+
+        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+          for (Tindex i = start_idx; i < end_idx; ++i) {
+            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+            auto a = accum_flat.template chip<0>(index);
+            auto g = grad_flat.template chip<0>(i);
+            auto v = var_flat.template chip<0>(index);
+            if (update_slots_) {
+              a += g.square();
+            }
+            v -= g.constant(lr_scalar) * g * a.rsqrt();
+          }
+        };
+
+        d.parallelFor(N, cost, shard);
+
       } else {
         auto indices_vec = indices.vec<Tindex>();
         auto var_flat = var.flat<T>();
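For readers less familiar with the Eigen Tensor expressions used inside the new shard, the sketch below (standalone, not part of the commit; shapes and values are invented) shows the same chip-based row update: chip<0>(index) gives a writable one-row view, and the Adagrad step is written with square(), constant(), and rsqrt().

// Standalone illustration of the per-row update performed by the shard above.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const float lr_scalar = 0.1f;
  Eigen::Tensor<float, 2> var(4, 3), accum(4, 3), grad(2, 3);
  var.setConstant(1.0f);
  accum.setConstant(0.1f);
  grad.setConstant(0.5f);

  const int index = 2;  // row of var/accum selected by indices_vec(i)
  const int i = 0;      // row of grad
  auto a = accum.chip<0>(index);  // writable 1-D view of one accumulator row
  auto g = grad.chip<0>(i);
  auto v = var.chip<0>(index);

  a += g.square();                             // accum_row += grad_row^2
  v -= g.constant(lr_scalar) * g * a.rsqrt();  // var_row -= lr * grad_row / sqrt(accum_row)

  std::cout << var << std::endl;
  return 0;
}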
@@ -1459,19 +1475,27 @@ class SparseApplyAdagradOp : public OpKernel {
         T lr_scalar = lr.scalar<T>()();
         const Tindex first_dim_size = accum_flat.size();
 
-        for (Tindex i = 0; i < N; i++) {
+        for (Tindex i = 0; i < N; ++i) {
           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                       errors::InvalidArgument(
                           strings::StrCat("Index ", index, " at offset ", i,
                                           " in indices is out of range")));
-          T& a = accum_flat(index);
-          const T& g = grad_flat(i);
-          if (update_slots_) {
-            a += g * g;
-          }
-          var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
         }
+
+        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+          for (Tindex i = start_idx; i < end_idx; ++i) {
+            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+            T& a = accum_flat(index);
+            const T& g = grad_flat(i);
+            if (update_slots_) {
+              a += g * g;
+            }
+            var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
+          }
+        };
+
+        d.parallelFor(N, cost, shard);
       }
     }
 
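One way to read the structure of the change (this reading is not stated in the commit message itself): in both the inner_dim > 1 and inner_dim == 1 branches, the bounds check over indices stays in a plain sequential loop that runs before d.parallelFor, so OP_REQUIRES can still fail the op early, and the shard lambdas only do arithmetic on already-validated indices. The TensorOpCost built from in_bytes, out_bytes, and cycles estimates the work for a single gradient row, which is what lets parallelFor pick a sensible number of rows per shard.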