Merge pull request #33115 from Intel-tensorflow:sriniva2/scatter_opt

PiperOrigin-RevId: 274202151
TensorFlower Gardener 2019-10-11 14:04:57 -07:00
commit 0f6190f86d
2 changed files with 137 additions and 9 deletions


@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
@@ -198,17 +199,58 @@ struct ScatterFunctor {
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
Index ParallelExecute(OpKernelContext* c, const Device& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
const Index kMaxLocks = 1024;
const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks;
// To reduce the number of locks and the memory usage, we divide the whole
// index space into kMaxLocks regions with each lock serializing access to
// a region.
mutex accessed[kMaxLocks];
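// For example, with limit = 4096 rows, entries_per_lock = 4, so row indices
// 0..3 map to lock 0, 4..7 to lock 1, and so on.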
std::atomic<Index> bad_index(-1);
auto ParallelScatter = [&](Index start, Index end) {
for (Index i = start; i < end; ++i) {
// Grab the index and check its validity. Do this carefully,
// to avoid checking the value and grabbing it again from
// memory a second time (a security risk since it may change in
// between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) {
bad_index = i;
return;
}
const Index lock_id = index / entries_per_lock;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
{
mutex_lock l(accessed[lock_id]);
scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
updates.template chip<0>(i));
}
}
};
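// Rough per-update cost passed to Shard(): each update copies one row of
// params.dimension(1) elements, weighted by an assumed per-element move cost.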
const float kMovingCost = 2.5f;
float shard_cost = kMovingCost * params.dimension(1);
const DeviceBase::CpuWorkerThreads& worker_threads =
*(c->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost,
ParallelScatter); // TODO: Come up with a good cost estimate.
return bad_index;
}
Index SerialExecute(OpKernelContext* c, const Device& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; ++i) {
// Grab the index and check its validity. Do this carefully,
// to avoid checking the value and grabbing it again from
// memory a second time (a security risk since it may change in
// between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -217,6 +259,37 @@ struct ScatterFunctorBase {
}
return -1;
}
Index operator()(OpKernelContext* c, const Device& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
#ifdef PLATFORM_GOOGLE
// The parallel version is significantly slower in Google-internal builds, so
// only call the serial version there for now.
// TODO(penporn): Avoid locking in parallelization (sort beforehand).
return SerialExecute(c, d, params, updates, indices);
#else
// indices and params sizes were validated in DoCompute().
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
const Index min_n_threshold = 1024;
const Index ser_par_ratio = 10000;
// To parallelize the updates, duplicate entries need to be handled
// correctly: multiple updates to the same index have to be serialized,
// and the resulting lock contention may nullify the benefits of
// parallelization. Assuming a uniform random distribution of indices, we
// use a rough heuristic to decide whether to execute the updates serially
// or in parallel. We also check N itself, since for small N the overhead
// of parallel execution outweighs its benefits.
const bool execute_serial =
((N < min_n_threshold) || ((N / limit) > ser_par_ratio));
if (execute_serial)
return SerialExecute(c, d, params, updates, indices);
else
return ParallelExecute(c, d, params, updates, indices);
#endif // PLATFORM_GOOGLE
}
};
template <typename Device, typename Index>
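For reference, the following standalone sketch distills the two ideas introduced in this file: the serial-versus-parallel dispatch heuristic and the lock striping used by ParallelExecute. It is illustrative only and not part of this commit; ChooseScatterPath, LockRegion, and ScatterPath are made-up names, while the threshold constants mirror the values above.

#include <cstdint>

// Illustrative only; mirrors min_n_threshold, ser_par_ratio, and kMaxLocks above.
enum class ScatterPath { kSerial, kParallel };

// Decide whether N updates into `limit` rows should run serially or in
// parallel: for small N the threading overhead dominates, and a large
// N / limit ratio implies many duplicate rows, i.e. heavy lock contention.
inline ScatterPath ChooseScatterPath(int64_t N, int64_t limit) {
  const int64_t kMinNThreshold = 1024;
  const int64_t kSerParRatio = 10000;
  if (N < kMinNThreshold || (N / limit) > kSerParRatio) {
    return ScatterPath::kSerial;
  }
  return ScatterPath::kParallel;
}

// Map a row index to one of at most kMaxLocks lock regions (lock striping),
// so concurrent updates to different regions do not block each other.
inline int64_t LockRegion(int64_t index, int64_t limit) {
  const int64_t kMaxLocks = 1024;
  const int64_t entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks;
  return index / entries_per_lock;
}

With these thresholds, the StressIndexTest added below (1,000,000 updates into a single row) stays on the serial path, while the new BM_ScatterAddInt32Large benchmark (1,000,000 updates spread over 10,000,000 / embedding_size rows) is large and spread out enough to take the parallel path when PLATFORM_GOOGLE is not defined.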


@@ -47,6 +47,17 @@ class ScatterUpdateOpTest : public OpsTestBase {
TF_ASSERT_OK(InitOp());
}
};
class ScatterSubOpTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_ref_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterSub")
.Input(FakeInput(variable_ref_type))
.Input(FakeInput(index_type))
.Input(FakeInput(RemoveRefType(variable_ref_type)))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
TEST_F(ScatterUpdateOpTest, Simple_StringType) {
MakeOp(DT_STRING_REF, DT_INT32);
@@ -175,6 +186,37 @@ TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
<< s;
}
TEST_F(ScatterSubOpTest, Error_IndexOutOfRange) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({14}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3}), {0, 1, 99});
AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
Status s = RunOpKernel();
EXPECT_TRUE(
absl::StrContains(s.ToString(), "indices[2] = 99 is not in [0, 14)"))
<< s;
}
TEST_F(ScatterSubOpTest, StressIndexTest) {
MakeOp(DT_INT32_REF, DT_INT32);
// Feed and run
const int kRows = 1;
std::vector<int32> values(kRows, 0);
const int kNumUpdates = 1000000;
std::vector<int32> indices(kNumUpdates, 0);
std::vector<int32> updates(kNumUpdates, 1);
AddInputFromArray<int32>(TensorShape({kRows}), values);
AddInputFromArray<int32>(TensorShape({kNumUpdates}), indices);
AddInputFromArray<int32>(TensorShape({kNumUpdates}), updates);
Status s = RunOpKernel();
Tensor params_tensor = *mutable_input(0).tensor;
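// All 1,000,000 ScatterSub updates of 1 target row 0, so the single param
// value should end at 0 - 1000000 = -1000000.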
Tensor expected(allocator(), DT_INT32, TensorShape({1}));
test::FillValues<int32>(&expected, {-1000000});
test::ExpectTensorEqual<int32>(expected, params_tensor);
}
TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
MakeOp(DT_FLOAT_REF, DT_INT32);
@@ -238,7 +280,8 @@ class ScatterUpdateBM : public ScatterUpdateOpTest {
};
template <typename Index>
static void BM_ScatterHelper(int iters, int embedding_size, const char* op,
bool big_num_updates = false) {
testing::StopTiming();
const int kRows = 10000000 / embedding_size;
std::vector<float> values;
@@ -246,7 +289,7 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
for (int i = 0; i < kRows * embedding_size; i++) {
values.push_back(i);
}
const int kNumUpdates = big_num_updates ? 1000000 : 1000;
random::PhiloxRandom philox(301, 17);
random::SimplePhilox rnd(&philox);
std::vector<Index> indices;
@@ -283,6 +326,10 @@ static void BM_ScatterUpdateInt64(int iters, int embedding_size) {
static void BM_ScatterAddInt32(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd");
}
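// Large-update variant: 1,000,000 updates instead of 1,000, which is enough
// for the dispatch heuristic to pick the parallel path in non-Google builds.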
static void BM_ScatterAddInt32Large(int iters, int embedding_size) {
BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd", true);
}
static void BM_ScatterAddInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd");
}
@@ -339,6 +386,14 @@ BENCHMARK(BM_ScatterUpdateInt64)
->Arg(100000);
BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterAddInt32Large)
->Arg(1)
->Arg(10)
->Arg(64)
->Arg(256)
->Arg(1024);
BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);