Split gather_nd CPU functors into multiple files for faster compile times.
Change: 129153756
This commit is contained in:
parent
21ca7e442c
commit
d3492d2a21
@ -379,8 +379,8 @@ tf_kernel_libraries(
|
||||
"batch_matrix_diag_op",
|
||||
"batch_matrix_set_diag_op",
|
||||
"edit_distance_op",
|
||||
"gather_nd_op",
|
||||
"gather_op",
|
||||
"gather_nd_op",
|
||||
"identity_op",
|
||||
"immutable_constant_op",
|
||||
"listdiff_op",
|
||||
|
@ -16,13 +16,11 @@ limitations under the License.
|
||||
// See docs in ../ops/array_ops.cc.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "tensorflow/core/kernels/gather_nd_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/gather_nd_op.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -155,97 +153,6 @@ class GatherNdOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization of GatherNdSlice to CPU
|
||||
namespace generator {
|
||||
|
||||
template <typename T, typename Index, int IXDIM>
|
||||
class GatherNdSliceGenerator {
|
||||
public:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator(
|
||||
const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices,
|
||||
typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
|
||||
typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc)
|
||||
: slice_size_(slice_size),
|
||||
Tindices_(Tindices),
|
||||
Tparams_(Tparams),
|
||||
Tout_(Tout),
|
||||
error_loc_(error_loc) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices(
|
||||
const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
|
||||
(*ix)[IXDIM] = 0;
|
||||
bool out_of_bounds = false;
|
||||
for (int i = 0; i < IXDIM; ++i) {
|
||||
const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
|
||||
(*ix)[i] = ix_i;
|
||||
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
|
||||
}
|
||||
return out_of_bounds;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
|
||||
operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
|
||||
const Index loc = loc_array[0];
|
||||
Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix;
|
||||
Eigen::array<Eigen::DenseIndex, 2> ix_out;
|
||||
ix_out[0] = loc;
|
||||
ix_out[1] = 0;
|
||||
const bool out_of_bounds = GenerateIndices(loc, &ix);
|
||||
if (TF_PREDICT_FALSE(out_of_bounds)) {
|
||||
error_loc_->store(loc);
|
||||
std::fill_n(&Tout_(ix_out), slice_size_, T());
|
||||
} else {
|
||||
std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out));
|
||||
}
|
||||
|
||||
return static_cast<int32>(0); // Return something...
|
||||
}
|
||||
|
||||
private:
|
||||
const Index slice_size_;
|
||||
const typename TTypes<Index>::ConstMatrix Tindices_;
|
||||
const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_;
|
||||
mutable typename TTypes<T>::Matrix Tout_;
|
||||
std::atomic<Index>* error_loc_;
|
||||
};
|
||||
|
||||
} // namespace generator
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename T, typename Index, int IXDIM>
|
||||
struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
|
||||
Index operator()(const CPUDevice& d, const Index slice_size,
|
||||
typename TTypes<int32>::Scalar Tscratch,
|
||||
typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
|
||||
typename TTypes<Index>::ConstMatrix Tindices,
|
||||
typename TTypes<T>::Matrix Tout) {
|
||||
std::atomic<Index> error_loc(-1);
|
||||
|
||||
const Eigen::DenseIndex batch_size = Tindices.dimension(0);
|
||||
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||
Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
|
||||
Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
|
||||
#else
|
||||
Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
|
||||
Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
|
||||
broadcast_dims.set(0, batch_size);
|
||||
#endif
|
||||
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
|
||||
slice_size, Tindices, Tparams, Tout, &error_loc);
|
||||
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
|
||||
.broadcast(broadcast_dims)
|
||||
.generate(gather_nd_generator)
|
||||
.sum();
|
||||
|
||||
// error_loc() returns -1 if there's no out-of-bounds index,
|
||||
// otherwise it returns the location of an OOB index in Tindices.
|
||||
return error_loc.load();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#define REGISTER_GATHER_ND_FULL(dev, type, index_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("GatherNd") \
|
||||
.Device(DEVICE_##dev) \
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
145
tensorflow/core/kernels/gather_nd_op_cpu_impl.h
Normal file
145
tensorflow/core/kernels/gather_nd_op_cpu_impl.h
Normal file
@ -0,0 +1,145 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
|
||||
#define TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
|
||||
|
||||
// Specialization of GatherNdSlice to CPU
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/gather_nd_op.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
namespace generator {
|
||||
|
||||
template <typename T, typename Index, int IXDIM>
|
||||
class GatherNdSliceGenerator {
|
||||
public:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator(
|
||||
const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices,
|
||||
typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
|
||||
typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc)
|
||||
: slice_size_(slice_size),
|
||||
Tindices_(Tindices),
|
||||
Tparams_(Tparams),
|
||||
Tout_(Tout),
|
||||
error_loc_(error_loc) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices(
|
||||
const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
|
||||
(*ix)[IXDIM] = 0;
|
||||
bool out_of_bounds = false;
|
||||
for (int i = 0; i < IXDIM; ++i) {
|
||||
const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
|
||||
(*ix)[i] = ix_i;
|
||||
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
|
||||
}
|
||||
return out_of_bounds;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
|
||||
operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
|
||||
const Index loc = loc_array[0];
|
||||
Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix;
|
||||
Eigen::array<Eigen::DenseIndex, 2> ix_out;
|
||||
ix_out[0] = loc;
|
||||
ix_out[1] = 0;
|
||||
const bool out_of_bounds = GenerateIndices(loc, &ix);
|
||||
if (TF_PREDICT_FALSE(out_of_bounds)) {
|
||||
error_loc_->store(loc);
|
||||
std::fill_n(&Tout_(ix_out), slice_size_, T());
|
||||
} else {
|
||||
std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out));
|
||||
}
|
||||
|
||||
return static_cast<int32>(0); // Return something...
|
||||
}
|
||||
|
||||
private:
|
||||
const Index slice_size_;
|
||||
const typename TTypes<Index>::ConstMatrix Tindices_;
|
||||
const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_;
|
||||
mutable typename TTypes<T>::Matrix Tout_;
|
||||
std::atomic<Index>* error_loc_;
|
||||
};
|
||||
|
||||
} // namespace generator
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename T, typename Index, int IXDIM>
|
||||
struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
|
||||
Index operator()(const CPUDevice& d, const Index slice_size,
|
||||
typename TTypes<int32>::Scalar Tscratch,
|
||||
typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
|
||||
typename TTypes<Index>::ConstMatrix Tindices,
|
||||
typename TTypes<T>::Matrix Tout) {
|
||||
std::atomic<Index> error_loc(-1);
|
||||
|
||||
const Eigen::DenseIndex batch_size = Tindices.dimension(0);
|
||||
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||
Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
|
||||
Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
|
||||
#else
|
||||
Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
|
||||
Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
|
||||
broadcast_dims.set(0, batch_size);
|
||||
#endif
|
||||
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
|
||||
slice_size, Tindices, Tparams, Tout, &error_loc);
|
||||
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
|
||||
.broadcast(broadcast_dims)
|
||||
.generate(gather_nd_generator)
|
||||
.sum();
|
||||
|
||||
// error_loc() returns -1 if there's no out-of-bounds index,
|
||||
// otherwise it returns the location of an OOB index in Tindices.
|
||||
return error_loc.load();
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_GATHER_ND_FULL(T, Index) \
|
||||
template Index GatherNdSlice<CPUDevice, T, Index, CPU_PROVIDED_IXDIM>:: \
|
||||
operator()(const CPUDevice& d, const Index slice_size, \
|
||||
typename TTypes<int32>::Scalar Tscratch, \
|
||||
typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::ConstTensor Tparams, \
|
||||
typename TTypes<Index>::ConstMatrix Tindices, \
|
||||
typename TTypes<T>::Matrix Tout);
|
||||
|
||||
#define REGISTER_GATHER_ND_CPU(type) \
|
||||
REGISTER_GATHER_ND_FULL(type, int32); \
|
||||
REGISTER_GATHER_ND_FULL(type, int64)
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
|
||||
|
||||
} // namespace functor
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
|
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
Normal file
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 0
|
||||
#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
Normal file
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 1
|
||||
#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
Normal file
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 2
|
||||
#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
Normal file
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 3
|
||||
#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
Normal file
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 4
|
||||
#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
Normal file
18
tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 5
|
||||
#include "tensorflow/core/kernels/gather_nd_op_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
Loading…
x
Reference in New Issue
Block a user