Parallelize SparseApplyAdagradOp

PiperOrigin-RevId: 254106236
Eugene Zhulenev 2019-06-19 17:36:45 -07:00 committed by TensorFlower Gardener
parent 223d17868c
commit c174697c09

@@ -1426,7 +1426,16 @@ class SparseApplyAdagradOp : public OpKernel {
                 errors::InvalidArgument(
                     "Inner dimension should be greater than zero."));
 
+    // This op is implemented only for CPU device.
+    const auto& d = ctx->eigen_cpu_device();
+
     if (N > 0) {
+      const int in_bytes = inner_dim * sizeof(T) * 3;
+      const int out_bytes = inner_dim * sizeof(T) * 2;
+      const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
+                                      Eigen::TensorOpCost::MulCost<T>() * 2);
+      const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
+
       if (inner_dim > 1) {
         const Tindex first_dim_size = var.dim_size(0);
         auto indices_vec = indices.vec<Tindex>();
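
The hunk above wires in Eigen's cost model for one row of the update: in_bytes counts the three rows read per index (var, accum, grad), out_bytes the two rows written back (var, accum), and cycles estimates two adds and two multiplies per element, all scaled by inner_dim. ThreadPoolDevice::parallelFor uses this estimate to pick shard sizes large enough to amortize scheduling overhead. A minimal standalone sketch of the same pattern, using plain Eigen with no TensorFlow types (the pool size, data, and names here are invented for the demo):

    // Minimal sketch of Eigen's cost-guided sharding (not TensorFlow code).
    #define EIGEN_USE_THREADS
    #include <cstdio>
    #include <vector>
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      Eigen::ThreadPool pool(4);  // arbitrary worker count for the demo
      Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

      std::vector<float> data(1 << 20, 1.0f);

      // Per-element cost estimate: bytes loaded, bytes stored, compute cycles.
      const Eigen::TensorOpCost cost(sizeof(float), sizeof(float),
                                     Eigen::TensorOpCost::MulCost<float>());

      // parallelFor splits [0, n) into shards sized from `cost` and runs the
      // lambda on the pool; each call owns the half-open range [start, end).
      device.parallelFor(static_cast<Eigen::Index>(data.size()), cost,
                         [&](Eigen::Index start, Eigen::Index end) {
                           for (Eigen::Index i = start; i < end; ++i) {
                             data[i] *= 2.0f;
                           }
                         });

      std::printf("data[0] = %f\n", data[0]);  // prints 2.000000
      return 0;
    }
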
@@ -1435,22 +1444,29 @@ class SparseApplyAdagradOp : public OpKernel {
         auto grad_flat = grad.flat_outer_dims<T>();
         T lr_scalar = lr.scalar<T>()();
 
         // Note(yonghui): It might be worth multi-threading square() and
         // rsqrt().
-        for (Tindex i = 0; i < N; i++) {
+        for (Tindex i = 0; i < N; ++i) {
           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                       errors::InvalidArgument(
                           strings::StrCat("Index ", index, " at offset ", i,
                                           " in indices is out of range")));
-          auto a = accum_flat.template chip<0>(index);
-          auto g = grad_flat.template chip<0>(i);
-          auto v = var_flat.template chip<0>(index);
-          if (update_slots_) {
-            a += g.square();
-          }
-          v -= g.constant(lr_scalar) * g * a.rsqrt();
         }
+
+        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+          for (Tindex i = start_idx; i < end_idx; ++i) {
+            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+            auto a = accum_flat.template chip<0>(index);
+            auto g = grad_flat.template chip<0>(i);
+            auto v = var_flat.template chip<0>(index);
+            if (update_slots_) {
+              a += g.square();
+            }
+            v -= g.constant(lr_scalar) * g * a.rsqrt();
+          }
+        };
+        d.parallelFor(N, cost, shard);
       } else {
         auto indices_vec = indices.vec<Tindex>();
         auto var_flat = var.flat<T>();
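
Note how this hunk splits the old loop in two. The OP_REQUIRES bounds check stays in a serial loop because it must be able to return from Compute() with an error status, which a lambda running on pool threads cannot do, while the chip-based row arithmetic moves unchanged into the shard. A self-contained sketch of those chip operations on a plain Eigen::Tensor (shapes, values, and the row index are invented for illustration):

    // chip<0>(r) views row r of a rank-2 tensor, so the Adagrad update is
    // expressed on whole rows without copying.
    #include <iostream>
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      Eigen::Tensor<float, 2> var(4, 3), accum(4, 3), grad(1, 3);
      var.setConstant(1.0f);
      accum.setConstant(1.0f);
      grad.setConstant(0.5f);
      const float lr = 0.1f;

      auto a = accum.chip<0>(2);  // row 2 of the accumulator
      auto g = grad.chip<0>(0);   // one gradient row
      auto v = var.chip<0>(2);    // row 2 of the variable

      a += g.square();                      // accum += grad^2 (elementwise)
      v -= g.constant(lr) * g * a.rsqrt();  // var -= lr * grad / sqrt(accum)

      std::cout << var << "\n";
      return 0;
    }
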
@@ -1459,19 +1475,27 @@ class SparseApplyAdagradOp : public OpKernel {
         T lr_scalar = lr.scalar<T>()();
         const Tindex first_dim_size = accum_flat.size();
 
-        for (Tindex i = 0; i < N; i++) {
+        for (Tindex i = 0; i < N; ++i) {
           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                       errors::InvalidArgument(
                           strings::StrCat("Index ", index, " at offset ", i,
                                           " in indices is out of range")));
-          T& a = accum_flat(index);
-          const T& g = grad_flat(i);
-          if (update_slots_) {
-            a += g * g;
-          }
-          var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
         }
+
+        const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
+          for (Tindex i = start_idx; i < end_idx; ++i) {
+            const Tindex index = internal::SubtleMustCopy(indices_vec(i));
+            T& a = accum_flat(index);
+            const T& g = grad_flat(i);
+            if (update_slots_) {
+              a += g * g;
+            }
+            var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
+          }
+        };
+
+        d.parallelFor(N, cost, shard);
       }
     }
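
The inner_dim == 1 branch gets the same validate-then-shard restructuring for scalar elements. One caveat carried over from the serial code: repeated values in indices get no special handling, and after this change two shards may update the same accumulator slot concurrently. A self-contained sketch of the overall two-phase shape, using invented data and std::vector state instead of TensorFlow tensors:

    // Phase 1 validates every index serially (the analogue of the
    // OP_REQUIRES loop); phase 2 runs the unchecked update inside
    // parallelFor shards.
    #define EIGEN_USE_THREADS
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      Eigen::ThreadPool pool(4);
      Eigen::ThreadPoolDevice d(&pool, pool.NumThreads());

      std::vector<float> var(100, 1.0f), accum(100, 0.1f);
      std::vector<float> grad = {0.5f, -0.25f, 1.0f};
      std::vector<std::ptrdiff_t> indices = {7, 42, 99};
      const std::ptrdiff_t N = static_cast<std::ptrdiff_t>(indices.size());
      const float lr = 0.01f;

      // Phase 1: serial bounds check; report the first bad index and stop.
      for (std::ptrdiff_t i = 0; i < N; ++i) {
        if (indices[i] < 0 ||
            indices[i] >= static_cast<std::ptrdiff_t>(var.size())) {
          std::fprintf(stderr, "Index %td at offset %td is out of range\n",
                       indices[i], i);
          return 1;
        }
      }

      // Phase 2: parallel update over [0, N), sharded by the cost model.
      const Eigen::TensorOpCost cost(
          sizeof(float) * 3, sizeof(float) * 2,
          Eigen::TensorOpCost::AddCost<float>() * 2 +
              Eigen::TensorOpCost::MulCost<float>() * 2);
      d.parallelFor(N, cost, [&](Eigen::Index start, Eigen::Index end) {
        for (Eigen::Index i = start; i < end; ++i) {
          const std::ptrdiff_t index = indices[i];
          accum[index] += grad[i] * grad[i];                     // accum += g^2
          var[index] -= lr * grad[i] / std::sqrt(accum[index]);  // Adagrad step
        }
      });

      std::printf("var[7] = %f\n", var[7]);
      return 0;
    }
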