diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f0cb90053e4..02ce47e364f 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -379,8 +379,8 @@ tf_kernel_libraries( "batch_matrix_diag_op", "batch_matrix_set_diag_op", "edit_distance_op", - "gather_nd_op", "gather_op", + "gather_nd_op", "identity_op", "immutable_constant_op", "listdiff_op", diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc index b4d9f03efc6..c2a5192efb1 100644 --- a/tensorflow/core/kernels/gather_nd_op.cc +++ b/tensorflow/core/kernels/gather_nd_op.cc @@ -16,13 +16,11 @@ limitations under the License. // See docs in ../ops/array_ops.cc. #define EIGEN_USE_THREADS -#include <atomic> - +#include "tensorflow/core/kernels/gather_nd_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/gather_nd_op.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" @@ -155,97 +153,6 @@ class GatherNdOp : public OpKernel { } }; -// Specialization of GatherNdSlice to CPU -namespace generator { - -template <typename T, typename Index, int IXDIM> -class GatherNdSliceGenerator { - public: - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator( - const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices, - typename TTypes<T, IXDIM + 1>::ConstTensor Tparams, - typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc) - : slice_size_(slice_size), - Tindices_(Tindices), - Tparams_(Tparams), - Tout_(Tout), - error_loc_(error_loc) {} - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices( - const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const { - (*ix)[IXDIM] = 0; - bool out_of_bounds = false; - for (int i = 0; i < IXDIM; ++i) { - const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i)); - (*ix)[i] = ix_i; - 
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i)); - } - return out_of_bounds; - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32 - operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const { - const Index loc = loc_array[0]; - Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix; - Eigen::array<Eigen::DenseIndex, 2> ix_out; - ix_out[0] = loc; - ix_out[1] = 0; - const bool out_of_bounds = GenerateIndices(loc, &ix); - if (TF_PREDICT_FALSE(out_of_bounds)) { - error_loc_->store(loc); - std::fill_n(&Tout_(ix_out), slice_size_, T()); - } else { - std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out)); - } - - return static_cast<int32>(0); // Return something... - } - - private: - const Index slice_size_; - const typename TTypes<Index>::ConstMatrix Tindices_; - const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_; - mutable typename TTypes<T>::Matrix Tout_; - std::atomic<Index>* error_loc_; -}; - -} // namespace generator - -namespace functor { - -template <typename T, typename Index, int IXDIM> -struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { - Index operator()(const CPUDevice& d, const Index slice_size, - typename TTypes<int32>::Scalar Tscratch, - typename TTypes<T, IXDIM + 1>::ConstTensor Tparams, - typename TTypes<Index>::ConstMatrix Tindices, - typename TTypes<T>::Matrix Tout) { - std::atomic<Index> error_loc(-1); - - const Eigen::DenseIndex batch_size = Tindices.dimension(0); -#if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }}; - Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }}; -#else - Eigen::IndexList<Eigen::type2index<1> > reshape_dims; - Eigen::IndexList<Eigen::DenseIndex> broadcast_dims; - broadcast_dims.set(0, batch_size); -#endif - generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator( - slice_size, Tindices, Tparams, Tout, &error_loc); - Tscratch.device(d) = Tscratch.reshape(reshape_dims) - .broadcast(broadcast_dims) - .generate(gather_nd_generator) - .sum(); - - // error_loc() returns -1 if there's no out-of-bounds index, - // otherwise it returns the location of an OOB index in Tindices. 
- return error_loc.load(); - } -}; - -} // namespace functor - #define REGISTER_GATHER_ND_FULL(dev, type, index_type) \ REGISTER_KERNEL_BUILDER(Name("GatherNd") \ .Device(DEVICE_##dev) \ diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h index 0ee783bd593..d7279d5712a 100644 --- a/tensorflow/core/kernels/gather_nd_op.h +++ b/tensorflow/core/kernels/gather_nd_op.h @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h new file mode 100644 index 00000000000..dc028c2f1e9 --- /dev/null +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -0,0 +1,145 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ +#define TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ + +// Specialization of GatherNdSlice to CPU + +#define EIGEN_USE_THREADS + +#include <atomic> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/gather_nd_op.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +namespace generator { + +template <typename T, typename Index, int IXDIM> +class GatherNdSliceGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator( + const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices, + typename TTypes<T, IXDIM + 1>::ConstTensor Tparams, + typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc) + : slice_size_(slice_size), + Tindices_(Tindices), + Tparams_(Tparams), + Tout_(Tout), + error_loc_(error_loc) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices( + const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const { + (*ix)[IXDIM] = 0; + bool out_of_bounds = false; + for (int i = 0; i < IXDIM; ++i) { + const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i)); + (*ix)[i] = ix_i; + out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i)); + } + return out_of_bounds; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32 + operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const { + const Index loc = loc_array[0]; + Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix; + Eigen::array<Eigen::DenseIndex, 2> ix_out; + ix_out[0] = loc; + ix_out[1] = 0; + const bool out_of_bounds = GenerateIndices(loc, &ix); + if (TF_PREDICT_FALSE(out_of_bounds)) { + error_loc_->store(loc); + std::fill_n(&Tout_(ix_out), slice_size_, T()); + } else { + std::copy_n(&Tparams_(ix), 
slice_size_, &Tout_(ix_out)); + } + + return static_cast<int32>(0); // Return something... + } + + private: + const Index slice_size_; + const typename TTypes<Index>::ConstMatrix Tindices_; + const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_; + mutable typename TTypes<T>::Matrix Tout_; + std::atomic<Index>* error_loc_; +}; + +} // namespace generator + +namespace functor { + +template <typename T, typename Index, int IXDIM> +struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { + Index operator()(const CPUDevice& d, const Index slice_size, + typename TTypes<int32>::Scalar Tscratch, + typename TTypes<T, IXDIM + 1>::ConstTensor Tparams, + typename TTypes<Index>::ConstMatrix Tindices, + typename TTypes<T>::Matrix Tout) { + std::atomic<Index> error_loc(-1); + + const Eigen::DenseIndex batch_size = Tindices.dimension(0); +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }}; + Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }}; +#else + Eigen::IndexList<Eigen::type2index<1> > reshape_dims; + Eigen::IndexList<Eigen::DenseIndex> broadcast_dims; + broadcast_dims.set(0, batch_size); +#endif + generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator( + slice_size, Tindices, Tparams, Tout, &error_loc); + Tscratch.device(d) = Tscratch.reshape(reshape_dims) + .broadcast(broadcast_dims) + .generate(gather_nd_generator) + .sum(); + + // error_loc() returns -1 if there's no out-of-bounds index, + // otherwise it returns the location of an OOB index in Tindices. 
+ return error_loc.load(); + } +}; + +#define REGISTER_GATHER_ND_FULL(T, Index) \ + template Index GatherNdSlice<CPUDevice, T, Index, CPU_PROVIDED_IXDIM>:: \ + operator()(const CPUDevice& d, const Index slice_size, \ + typename TTypes<int32>::Scalar Tscratch, \ + typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::ConstTensor Tparams, \ + typename TTypes<Index>::ConstMatrix Tindices, \ + typename TTypes<T>::Matrix Tout); + +#define REGISTER_GATHER_ND_CPU(type) \ + REGISTER_GATHER_ND_FULL(type, int32); \ + REGISTER_GATHER_ND_FULL(type, int64) + +TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc b/tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc new file mode 100644 index 00000000000..246e9f729b8 --- /dev/null +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc @@ -0,0 +1,18 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define CPU_PROVIDED_IXDIM 0 +#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h" +#undef CPU_PROVIDED_IXDIM diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc b/tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc new file mode 100644 index 00000000000..5b7720fc4ef --- /dev/null +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc @@ -0,0 +1,18 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define CPU_PROVIDED_IXDIM 1 +#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h" +#undef CPU_PROVIDED_IXDIM diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc b/tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc new file mode 100644 index 00000000000..0f6932394ed --- /dev/null +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc @@ -0,0 +1,18 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define CPU_PROVIDED_IXDIM 2 +#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h" +#undef CPU_PROVIDED_IXDIM diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc b/tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc new file mode 100644 index 00000000000..1c2aec7820a --- /dev/null +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc @@ -0,0 +1,18 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define CPU_PROVIDED_IXDIM 3 +#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h" +#undef CPU_PROVIDED_IXDIM diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc b/tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc new file mode 100644 index 00000000000..3e164668c5b --- /dev/null +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc @@ -0,0 +1,18 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define CPU_PROVIDED_IXDIM 4 +#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h" +#undef CPU_PROVIDED_IXDIM diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc b/tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc new file mode 100644 index 00000000000..7141ea70df9 --- /dev/null +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc @@ -0,0 +1,18 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define CPU_PROVIDED_IXDIM 5 +#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h" +#undef CPU_PROVIDED_IXDIM