Merge pull request #33115 from Intel-tensorflow:sriniva2/scatter_opt
PiperOrigin-RevId: 274202151
This commit is contained in:
		
						commit
						0f6190f86d
					
				| @ -26,6 +26,7 @@ limitations under the License. | ||||
| #include "tensorflow/core/framework/variant_op_registry.h" | ||||
| #include "tensorflow/core/kernels/dense_update_functor.h" | ||||
| #include "tensorflow/core/platform/types.h" | ||||
| #include "tensorflow/core/util/work_sharder.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| 
 | ||||
| @ -198,17 +199,58 @@ struct ScatterFunctor { | ||||
| 
 | ||||
| template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> | ||||
| struct ScatterFunctorBase { | ||||
|   Index operator()(OpKernelContext* c, const Device& d, | ||||
|                    typename TTypes<T>::Matrix params, | ||||
|                    typename TTypes<T>::ConstMatrix updates, | ||||
|                    typename TTypes<Index>::ConstFlat indices) { | ||||
|     // indices and params sizes were validated in DoCompute().
 | ||||
|   Index ParallelExecute(OpKernelContext* c, const Device& d, | ||||
|                         typename TTypes<T>::Matrix params, | ||||
|                         typename TTypes<T>::ConstMatrix updates, | ||||
|                         typename TTypes<Index>::ConstFlat indices) { | ||||
|     const Index N = static_cast<Index>(indices.size()); | ||||
|     const Index limit = static_cast<Index>(params.dimension(0)); | ||||
|     for (Index i = 0; i < N; i++) { | ||||
|     const Index kMaxLocks = 1024; | ||||
|     const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks; | ||||
|     // To reduce the number of locks and the memory usage, we divide the whole
 | ||||
|     // index space into kMaxLocks regions with each lock serializing access to
 | ||||
|     // a region.
 | ||||
|     mutex accessed[kMaxLocks]; | ||||
|     std::atomic<Index> bad_index(-1); | ||||
|     auto ParallelScatter = [&](Index start, Index end) { | ||||
|       for (Index i = start; i < end; ++i) { | ||||
|         // Grab the index and check its validity.  Do this carefully,
 | ||||
|         // to avoid checking the value and grabbing it again from
 | ||||
|         // memory a second time (a security risk since it may change in
 | ||||
|         // between).
 | ||||
|         const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); | ||||
|         if (!FastBoundsCheck(index, limit)) { | ||||
|           bad_index = i; | ||||
|           return; | ||||
|         } | ||||
|         const Index lock_id = index / entries_per_lock; | ||||
|         // Copy last Ndim-1 dimensions of updates[i] to params[index]
 | ||||
|         { | ||||
|           mutex_lock l(accessed[lock_id]); | ||||
|           scatter_op::internal::Assign<op>::Run(params.template chip<0>(index), | ||||
|                                                 updates.template chip<0>(i)); | ||||
|         } | ||||
|       } | ||||
|     }; | ||||
|     const float kMovingCost = 2.5f; | ||||
|     float shard_cost = kMovingCost * params.dimension(1); | ||||
|     const DeviceBase::CpuWorkerThreads& worker_threads = | ||||
|         *(c->device()->tensorflow_cpu_worker_threads()); | ||||
|     Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost, | ||||
|           ParallelScatter);  // TODO: Come up with a good cost estimate.
 | ||||
|     return bad_index; | ||||
|   } | ||||
|   Index SerialExecute(OpKernelContext* c, const Device& d, | ||||
|                       typename TTypes<T>::Matrix params, | ||||
|                       typename TTypes<T>::ConstMatrix updates, | ||||
|                       typename TTypes<Index>::ConstFlat indices) { | ||||
|     const Index N = static_cast<Index>(indices.size()); | ||||
|     const Index limit = static_cast<Index>(params.dimension(0)); | ||||
|     for (Index i = 0; i < N; ++i) { | ||||
|       // Grab the index and check its validity.  Do this carefully,
 | ||||
|       // to avoid checking the value and grabbing it again from
 | ||||
|       // memory a second time (a security risk since it may change in between).
 | ||||
|       // memory a second time (a security risk since it may change in
 | ||||
|       // between).
 | ||||
|       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); | ||||
|       if (!FastBoundsCheck(index, limit)) return i; | ||||
|       // Copy last Ndim-1 dimensions of updates[i] to params[index]
 | ||||
| @ -217,6 +259,37 @@ struct ScatterFunctorBase { | ||||
|     } | ||||
|     return -1; | ||||
|   } | ||||
| 
 | ||||
|   Index operator()(OpKernelContext* c, const Device& d, | ||||
|                    typename TTypes<T>::Matrix params, | ||||
|                    typename TTypes<T>::ConstMatrix updates, | ||||
|                    typename TTypes<Index>::ConstFlat indices) { | ||||
| #ifdef PLATFORM_GOOGLE | ||||
|     // The parallel version is significantly slower internally. Only call the
 | ||||
|     // serial version for now.
 | ||||
|     // TODO(penporn): Avoid locking in parallelization (sort beforehand).
 | ||||
|     return SerialExecute(c, d, params, updates, indices); | ||||
| #else | ||||
|     // indices and params sizes were validated in DoCompute().
 | ||||
|     const Index N = static_cast<Index>(indices.size()); | ||||
|     const Index limit = static_cast<Index>(params.dimension(0)); | ||||
|     const Index min_n_threshold = 1024; | ||||
|     const Index ser_par_ratio = 10000; | ||||
|     // For parallelizing the updates, duplicate entries need to be handled
 | ||||
|     // correctly. Multiple updates to the same index has to be serialized.
 | ||||
|     // This can lead to lock contention which may nullify the benefits of
 | ||||
|     // parallelization. Assuming uniform random distribution of the indices, we
 | ||||
|     // come up with a rough heuristic and determine whether the updates execute
 | ||||
|     // serially or parallelly. Also if 'N' is small, overheads of parallel
 | ||||
|     // execution outweigh its benefits and hence we check the value of N.
 | ||||
|     const bool execute_serial = | ||||
|         ((N < min_n_threshold) || ((N / limit) > ser_par_ratio)); | ||||
|     if (execute_serial) | ||||
|       return SerialExecute(c, d, params, updates, indices); | ||||
|     else | ||||
|       return ParallelExecute(c, d, params, updates, indices); | ||||
| #endif  // PLATFORM_GOOGLE
 | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| template <typename Device, typename Index> | ||||
|  | ||||
| @ -47,6 +47,17 @@ class ScatterUpdateOpTest : public OpsTestBase { | ||||
|     TF_ASSERT_OK(InitOp()); | ||||
|   } | ||||
| }; | ||||
| class ScatterSubOpTest : public OpsTestBase { | ||||
|  protected: | ||||
|   void MakeOp(DataType variable_ref_type, DataType index_type) { | ||||
|     TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterSub") | ||||
|                      .Input(FakeInput(variable_ref_type)) | ||||
|                      .Input(FakeInput(index_type)) | ||||
|                      .Input(FakeInput(RemoveRefType(variable_ref_type))) | ||||
|                      .Finalize(node_def())); | ||||
|     TF_ASSERT_OK(InitOp()); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| TEST_F(ScatterUpdateOpTest, Simple_StringType) { | ||||
|   MakeOp(DT_STRING_REF, DT_INT32); | ||||
| @ -175,6 +186,37 @@ TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) { | ||||
|       << s; | ||||
| } | ||||
| 
 | ||||
| TEST_F(ScatterSubOpTest, Error_IndexOutOfRange) { | ||||
|   MakeOp(DT_FLOAT_REF, DT_INT32); | ||||
|   // Feed and run
 | ||||
|   AddInputFromArray<float>(TensorShape({14}), | ||||
|                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); | ||||
|   AddInputFromArray<int32>(TensorShape({3}), {0, 1, 99}); | ||||
|   AddInputFromArray<float>(TensorShape({3}), {100, 101, 102}); | ||||
|   Status s = RunOpKernel(); | ||||
|   EXPECT_TRUE( | ||||
|       absl::StrContains(s.ToString(), "indices[2] = 99 is not in [0, 14)")) | ||||
|       << s; | ||||
| } | ||||
| 
 | ||||
| TEST_F(ScatterSubOpTest, StressIndexTest) { | ||||
|   MakeOp(DT_INT32_REF, DT_INT32); | ||||
|   // Feed and run
 | ||||
|   const int kRows = 1; | ||||
|   std::vector<int32> values(kRows, 0); | ||||
|   const int kNumUpdates = 1000000; | ||||
|   std::vector<int32> indices(kNumUpdates, 0); | ||||
|   std::vector<int32> updates(kNumUpdates, 1); | ||||
|   AddInputFromArray<int32>(TensorShape({kRows}), values); | ||||
|   AddInputFromArray<int32>(TensorShape({kNumUpdates}), indices); | ||||
|   AddInputFromArray<int32>(TensorShape({kNumUpdates}), updates); | ||||
|   Status s = RunOpKernel(); | ||||
|   Tensor params_tensor = *mutable_input(0).tensor; | ||||
|   Tensor expected(allocator(), DT_INT32, TensorShape({1})); | ||||
|   test::FillValues<int32>(&expected, {-1000000}); | ||||
|   test::ExpectTensorEqual<int32>(expected, params_tensor); | ||||
| } | ||||
| 
 | ||||
| TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { | ||||
|   MakeOp(DT_FLOAT_REF, DT_INT32); | ||||
| 
 | ||||
| @ -238,7 +280,8 @@ class ScatterUpdateBM : public ScatterUpdateOpTest { | ||||
| }; | ||||
| 
 | ||||
| template <typename Index> | ||||
| static void BM_ScatterHelper(int iters, int embedding_size, const char* op) { | ||||
| static void BM_ScatterHelper(int iters, int embedding_size, const char* op, | ||||
|                              bool big_num_updates = false) { | ||||
|   testing::StopTiming(); | ||||
|   const int kRows = 10000000 / embedding_size; | ||||
|   std::vector<float> values; | ||||
| @ -246,7 +289,7 @@ static void BM_ScatterHelper(int iters, int embedding_size, const char* op) { | ||||
|   for (int i = 0; i < kRows * embedding_size; i++) { | ||||
|     values.push_back(i); | ||||
|   } | ||||
|   const int kNumUpdates = 1000; | ||||
|   const int kNumUpdates = big_num_updates ? 1000000 : 1000; | ||||
|   random::PhiloxRandom philox(301, 17); | ||||
|   random::SimplePhilox rnd(&philox); | ||||
|   std::vector<Index> indices; | ||||
| @ -283,6 +326,10 @@ static void BM_ScatterUpdateInt64(int iters, int embedding_size) { | ||||
| static void BM_ScatterAddInt32(int iters, int embedding_size) { | ||||
|   BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd"); | ||||
| } | ||||
| 
 | ||||
| static void BM_ScatterAddInt32Large(int iters, int embedding_size) { | ||||
|   BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd", true); | ||||
| } | ||||
| static void BM_ScatterAddInt64(int iters, int embedding_size) { | ||||
|   BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd"); | ||||
| } | ||||
| @ -339,6 +386,14 @@ BENCHMARK(BM_ScatterUpdateInt64) | ||||
|     ->Arg(100000); | ||||
| 
 | ||||
| BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); | ||||
| 
 | ||||
| BENCHMARK(BM_ScatterAddInt32Large) | ||||
|     ->Arg(1) | ||||
|     ->Arg(10) | ||||
|     ->Arg(64) | ||||
|     ->Arg(256) | ||||
|     ->Arg(1024); | ||||
| 
 | ||||
| BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); | ||||
| 
 | ||||
| BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user