Merge pull request #33115 from Intel-tensorflow:sriniva2/scatter_opt
PiperOrigin-RevId: 274202151
This commit is contained in:
commit 0f6190f86d
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

@@ -198,17 +199,58 @@ struct ScatterFunctor {

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
  Index ParallelExecute(OpKernelContext* c, const Device& d,
                        typename TTypes<T>::Matrix params,
                        typename TTypes<T>::ConstMatrix updates,
                        typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
    const Index kMaxLocks = 1024;
    const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks;
    // To reduce the number of locks and the memory usage, we divide the whole
    // index space into kMaxLocks regions with each lock serializing access to
    // a region.
    mutex accessed[kMaxLocks];
    std::atomic<Index> bad_index(-1);
    auto ParallelScatter = [&](Index start, Index end) {
      for (Index i = start; i < end; ++i) {
        // Grab the index and check its validity. Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) {
          bad_index = i;
          return;
        }
        const Index lock_id = index / entries_per_lock;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        {
          mutex_lock l(accessed[lock_id]);
          scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                                updates.template chip<0>(i));
        }
      }
    };
    const float kMovingCost = 2.5f;
    float shard_cost = kMovingCost * params.dimension(1);
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(c->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost,
          ParallelScatter);  // TODO: Come up with a good cost estimate.
    return bad_index;
  }
  Index SerialExecute(OpKernelContext* c, const Device& d,
                      typename TTypes<T>::Matrix params,
                      typename TTypes<T>::ConstMatrix updates,
                      typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; ++i) {
      // Grab the index and check its validity. Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      // memory a second time (a security risk since it may change in
      // between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
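To make the region-locking scheme in ParallelExecute easier to follow outside the TensorFlow kernel machinery, here is a minimal standalone sketch of the same idea using plain std::thread and std::mutex. It is an illustration only, not the code in this change: the function name ParallelScatterAddSketch, the plain std::vector buffers, and the fixed thread partitioning are made up for the example, whereas the real kernel shards work with Shard() and applies the update via scatter_op::internal::Assign<op>::Run.

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

// Sketch: apply updates[i] to params[indices[i]] with '+=' in parallel.
// The destination index space [0, limit) is split into at most kMaxLocks
// regions; a worker only takes the mutex guarding the region its destination
// row falls into, so only updates that collide in the same region serialize.
// Returns the position of the first out-of-range index, or -1 if all are valid.
int64_t ParallelScatterAddSketch(std::vector<float>& params,
                                 const std::vector<int64_t>& indices,
                                 const std::vector<float>& updates,
                                 int num_threads) {
  const int64_t limit = static_cast<int64_t>(params.size());
  const int64_t n = static_cast<int64_t>(indices.size());
  constexpr int64_t kMaxLocks = 1024;
  const int64_t entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks;
  std::vector<std::mutex> region_locks(kMaxLocks);
  std::atomic<int64_t> bad_index(-1);

  auto worker = [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      const int64_t index = indices[i];
      if (index < 0 || index >= limit) {
        bad_index = i;
        return;
      }
      // Lock only the region that owns this destination row.
      std::lock_guard<std::mutex> lock(region_locks[index / entries_per_lock]);
      params[index] += updates[i];
    }
  };

  // Static partitioning of the update range across threads (the real kernel
  // uses Shard() with a per-row cost estimate instead).
  std::vector<std::thread> threads;
  const int64_t chunk = (n + num_threads - 1) / num_threads;
  for (int64_t start = 0; start < n; start += chunk) {
    threads.emplace_back(worker, start, std::min(start + chunk, n));
  }
  for (auto& t : threads) t.join();
  return bad_index.load();
}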
@@ -217,6 +259,37 @@ struct ScatterFunctorBase {
    }
    return -1;
  }

  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
#ifdef PLATFORM_GOOGLE
    // The parallel version is significantly slower internally. Only call the
    // serial version for now.
    // TODO(penporn): Avoid locking in parallelization (sort beforehand).
    return SerialExecute(c, d, params, updates, indices);
#else
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index min_n_threshold = 1024;
    const Index ser_par_ratio = 10000;
    // For parallelizing the updates, duplicate entries need to be handled
    // correctly: multiple updates to the same index have to be serialized.
    // This can lead to lock contention which may nullify the benefits of
    // parallelization. Assuming a uniform random distribution of the indices,
    // we use a rough heuristic to decide whether the updates execute serially
    // or in parallel. Also, if 'N' is small, the overhead of parallel
    // execution outweighs its benefits, so we check the value of N as well.
    const bool execute_serial =
        ((N < min_n_threshold) || ((N / limit) > ser_par_ratio));
    if (execute_serial)
      return SerialExecute(c, d, params, updates, indices);
    else
      return ParallelExecute(c, d, params, updates, indices);
#endif  // PLATFORM_GOOGLE
  }
};
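As an illustrative worked example of the heuristic above (the numbers here are made up, not taken from the change): with limit = 100 destination rows and N = 2,000,000 uniformly distributed updates, N / limit = 20,000 exceeds ser_par_ratio = 10,000, so each row is expected to be hit roughly 20,000 times, the per-region mutexes would be heavily contended, and SerialExecute is chosen. With limit = 1,000,000 rows and N = 100,000 updates, N is at least min_n_threshold and N / limit rounds down to 0, so ParallelExecute is used. Any N below 1024 always takes the serial path regardless of limit.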

template <typename Device, typename Index>

@@ -47,6 +47,17 @@ class ScatterUpdateOpTest : public OpsTestBase {
    TF_ASSERT_OK(InitOp());
  }
};

class ScatterSubOpTest : public OpsTestBase {
 protected:
  void MakeOp(DataType variable_ref_type, DataType index_type) {
    TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterSub")
                     .Input(FakeInput(variable_ref_type))
                     .Input(FakeInput(index_type))
                     .Input(FakeInput(RemoveRefType(variable_ref_type)))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

TEST_F(ScatterUpdateOpTest, Simple_StringType) {
  MakeOp(DT_STRING_REF, DT_INT32);
@@ -175,6 +186,37 @@ TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
      << s;
}

TEST_F(ScatterSubOpTest, Error_IndexOutOfRange) {
  MakeOp(DT_FLOAT_REF, DT_INT32);
  // Feed and run
  AddInputFromArray<float>(TensorShape({14}),
                           {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  AddInputFromArray<int32>(TensorShape({3}), {0, 1, 99});
  AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
  Status s = RunOpKernel();
  EXPECT_TRUE(
      absl::StrContains(s.ToString(), "indices[2] = 99 is not in [0, 14)"))
      << s;
}

TEST_F(ScatterSubOpTest, StressIndexTest) {
  MakeOp(DT_INT32_REF, DT_INT32);
  // Feed and run
  const int kRows = 1;
  std::vector<int32> values(kRows, 0);
  const int kNumUpdates = 1000000;
  std::vector<int32> indices(kNumUpdates, 0);
  std::vector<int32> updates(kNumUpdates, 1);
  AddInputFromArray<int32>(TensorShape({kRows}), values);
  AddInputFromArray<int32>(TensorShape({kNumUpdates}), indices);
  AddInputFromArray<int32>(TensorShape({kNumUpdates}), updates);
  Status s = RunOpKernel();
  Tensor params_tensor = *mutable_input(0).tensor;
  Tensor expected(allocator(), DT_INT32, TensorShape({1}));
  test::FillValues<int32>(&expected, {-1000000});
  test::ExpectTensorEqual<int32>(expected, params_tensor);
}
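A note on the expected value in StressIndexTest: all 1,000,000 indices point at the single row of a one-element variable initialized to 0, and ScatterSub subtracts the update value 1 each time, so the result is 0 - 1,000,000 * 1 = -1000000. Reading it against the heuristic in scatter_functor.h (an inference from the code, not something stated in the commit), N / limit = 1,000,000 far exceeds ser_par_ratio, so this worst-case duplicate-index input would be routed to SerialExecute.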

TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
  MakeOp(DT_FLOAT_REF, DT_INT32);

@@ -238,7 +280,8 @@ class ScatterUpdateBM : public ScatterUpdateOpTest {
};

template <typename Index>
static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
static void BM_ScatterHelper(int iters, int embedding_size, const char* op,
                             bool big_num_updates = false) {
  testing::StopTiming();
  const int kRows = 10000000 / embedding_size;
  std::vector<float> values;
@@ -246,7 +289,7 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
  for (int i = 0; i < kRows * embedding_size; i++) {
    values.push_back(i);
  }
  const int kNumUpdates = 1000;
  const int kNumUpdates = big_num_updates ? 1000000 : 1000;
  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  std::vector<Index> indices;
@@ -283,6 +326,10 @@ static void BM_ScatterUpdateInt64(int iters, int embedding_size) {
static void BM_ScatterAddInt32(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd");
}

static void BM_ScatterAddInt32Large(int iters, int embedding_size) {
  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd", true);
}
static void BM_ScatterAddInt64(int iters, int embedding_size) {
  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd");
}
@@ -339,6 +386,14 @@ BENCHMARK(BM_ScatterUpdateInt64)
    ->Arg(100000);

BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);

BENCHMARK(BM_ScatterAddInt32Large)
    ->Arg(1)
    ->Arg(10)
    ->Arg(64)
    ->Arg(256)
    ->Arg(1024);

BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);

BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
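Reading the new benchmark against the heuristic in scatter_functor.h (this is an inference from the numbers above, not something stated in the commit): the pre-existing benchmarks issue only kNumUpdates = 1000 updates, which is below min_n_threshold = 1024, so they always exercise the serial path. BM_ScatterAddInt32Large issues 1,000,000 updates, and with kRows = 10000000 / embedding_size (about 9765 rows at embedding_size = 1024) the ratio N / limit is roughly 102, well under ser_par_ratio = 10000, so it exercises ParallelExecute in non-PLATFORM_GOOGLE builds.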