Created a Gather functor that works with batch dimensions.
PiperOrigin-RevId: 261887312
commit ecc0b95092 (parent ee202801b3)
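
For orientation: the new GatherFunctorBatched backs gather with batch_dims > 0 (the array_ops.gather(..., axis=..., batch_dims=...) calls in the test changes below). A minimal sketch of the intended semantics, assuming batch_dims=1 and axis=1 and using made-up NumPy values rather than anything from this change:

    import numpy as np

    # Illustrative shapes only: params is [batch, gather_dim], indices is [batch, n].
    params = np.array([[10, 11, 12],
                       [20, 21, 22]])
    indices = np.array([[2, 0],
                        [1, 1]])

    # result[b, j] = params[b, indices[b, j]]
    result = np.stack([p[idx] for p, idx in zip(params, indices)])
    print(result)  # [[12 10]
                   #  [21 21]]
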
@@ -129,6 +129,7 @@ tensorflow/core/kernels/function_ops.cc
 tensorflow/core/kernels/fused_batch_norm_op.cc
 tensorflow/core/kernels/fused_eigen_output_kernels.cc
 tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_functor_batched.cc
 tensorflow/core/kernels/gather_nd_op.cc
 tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
 tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
@@ -1887,7 +1887,10 @@ tf_kernel_library(
 # Unlike gather_functor library, this does not include the CUDA code and deps.
 cc_library(
     name = "gather_functor_hdr",
-    hdrs = ["gather_functor.h"],
+    hdrs = [
+        "gather_functor.h",
+        "gather_functor_batched.h",
+    ],
 )
 
 tf_kernel_library(
@@ -6031,6 +6034,7 @@ filegroup(
         "function_ops.cc",
         "function_ops.h",
         "gather_functor.h",
+        "gather_functor_batched.h",
         "gather_nd_op.cc",
         "gather_nd_op.h",
         "gather_nd_op_cpu_impl.h",
tensorflow/core/kernels/gather_functor_batched.cc (new file)
@@ -0,0 +1,55 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Forward declarations of the functor specializations for GPU.
#define DECLARE_GPU_SPECS_INDEX(T, Index)                                \
  template <>                                                            \
  int64 GatherFunctorBatched<GPUDevice, T, Index>::operator()(           \
      OpKernelContext* ctx, typename TTypes<T, 4>::ConstTensor Tparams,  \
      typename TTypes<Index>::ConstFlat Tindices,                        \
      typename TTypes<T, 4>::Tensor Tout);                               \
  extern template struct GatherFunctorBatched<GPUDevice, T, Index>;

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_int64(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX

}  // namespace functor
}  // namespace tensorflow

#else

#include "tensorflow/core/kernels/gather_functor_batched.h"

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
tensorflow/core/kernels/gather_functor_batched.h (new file)
@@ -0,0 +1,197 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Helper method to copy using memcpy.
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
SliceIndex HandleCopiesBatched(OpKernelContext* ctx,
                               typename TTypes<T, 4>::ConstTensor params,
                               typename TTypes<Index>::ConstFlat indices,
                               SliceIndex slice_elems,
                               typename TTypes<T, 4>::Tensor out) {
  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
  const SliceIndex outer_size = static_cast<SliceIndex>(params.dimension(1));
  const SliceIndex indices_size =
      static_cast<SliceIndex>(indices.dimension(0)) / batch_size;

  const Index limit = static_cast<Index>(params.dimension(2));
  if (static_slice_elems >= 0) {
    // Give the compiler static knowledge of the number of elements/bytes.
    slice_elems = static_slice_elems;
  }
  // Compute slice_bytes here so that static knowledge is available.
  const size_t slice_bytes = slice_elems * sizeof(T);
  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
  mutex mu;
  // Stores the value of the first invalid index for error reporting; it is a
  // shared variable.
  SliceIndex result = -1;
  auto work = [&](int64 start, int64 end) {
    const int64 r_start = start % (outer_size * indices_size);
    SliceIndex batch_idx =
        static_cast<SliceIndex>(start / (outer_size * indices_size));
    SliceIndex outer_idx = static_cast<SliceIndex>(r_start / indices_size);
    SliceIndex indices_idx = static_cast<SliceIndex>(r_start % indices_size);

    SliceIndex batch_offset = batch_idx * indices_size;
    for (; start < end; ++start) {
      SliceIndex i_next = indices_idx + 1;
      SliceIndex o_next = outer_idx;
      SliceIndex b_next = batch_idx;
      SliceIndex b_offset_next = batch_offset;

      if (i_next >= indices_size) {
        i_next = 0;
        if (++o_next >= outer_size) {
          o_next = 0;
          ++b_next;
          b_offset_next += indices_size;
        }
      }
      if (start + 1 < end) {
        port::prefetch<port::PREFETCH_HINT_T0>(
            &params(b_next, o_next, indices(b_offset_next + i_next), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, o_next, i_next, 0));
      }
      const Index index =
          internal::SubtleMustCopy(indices(batch_offset + indices_idx));
      if (!FastBoundsCheck(index, limit)) {
        mutex_lock l(mu);
        result = batch_offset + indices_idx;
        return;
      }

      // Copy using memcpy if possible, otherwise an Eigen loop.
      // TODO(cwhipkey): avoid linking to framework to get Allocator (to
      // improve ahead-of-time compilation binary size).
      if (is_simple_type<T>::value) {
        // Avoid auto-promotion to Index from SliceIndex by casting.
        memcpy(
            &out(batch_idx, outer_idx, indices_idx, 0),
            &params(batch_idx, outer_idx, static_cast<SliceIndex>(index), 0),
            slice_bytes);
      } else {
        // For non-"simple" types (e.g. strings).
        out.template chip<2>(indices_idx) = params.template chip<2>(index);
      }

      indices_idx = i_next;
      outer_idx = o_next;
      batch_idx = b_next;
      batch_offset = b_offset_next;
    }
  };

  Shard(worker_threads->num_threads, worker_threads->workers,
        batch_size * outer_size * indices_size, slice_elems * sizeof(T), work);
  return result;
}

template <typename T, typename Index>
struct GatherFunctorBatchedCPU {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out) {
    const int64 indices_size = indices.size();  // Includes the batch_size.
    const int64 slice_size = out.dimension(3);
    int64 bad_i;

    const int64 batch_size = params.dimension(0);
    const int64 outer_size = params.dimension(1);

    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                      params.size() > std::numeric_limits<int32>::max() ||
                      indices_size > std::numeric_limits<int32>::max() ||
                      batch_size * outer_size * indices_size * slice_size >
                          std::numeric_limits<int32>::max());
#define CALL(elems)                                              \
  do {                                                           \
    if (use_large) {                                             \
      bad_i = HandleCopiesBatched<T, Index, int64, elems>(       \
          ctx, params, indices, slice_size, out);                \
    } else {                                                     \
      const int32 small_slice = static_cast<int32>(slice_size);  \
      bad_i = HandleCopiesBatched<T, Index, int32, elems>(       \
          ctx, params, indices, small_slice, out);               \
    }                                                            \
  } while (0)

    // TODO(rmlarsen): Investigate whether these specializations are still
    // needed and, if yes, whether the slice sizes are appropriate.
    if (slice_size == 10)
      CALL(10);
    else if (slice_size == 20)
      CALL(20);
    else
      CALL(-1);
#undef CALL

    return bad_i;
  }
};

template <typename Device, typename T, typename Index>
struct GatherFunctorBatched {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out);
};

template <typename T, typename Index>
struct GatherFunctorBatched<CPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out) {
    return GatherFunctorBatchedCPU<T, Index>()(ctx, params, indices, out);
  }
};

template <typename Index>
struct GatherFunctorBatched<GPUDevice, Variant, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<Variant, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<Variant, 4>::Tensor out) {
    return GatherFunctorBatchedCPU<Variant, Index>()(ctx, params, indices,
                                                     out);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
|
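A note on HandleCopiesBatched above: the Shard callback walks a flat work range of size batch_size * outer_size * indices_size and decomposes each position into a (batch, outer, index-within-batch) triple, then copies one slice of slice_elems elements per position. A small sketch of that decomposition in plain Python (it mirrors the arithmetic above; it is not TensorFlow code):

    def decompose(start, outer_size, indices_size):
      # r_start = start % (outer_size * indices_size), then split it further.
      r_start = start % (outer_size * indices_size)
      batch_idx = start // (outer_size * indices_size)
      outer_idx = r_start // indices_size
      indices_idx = r_start % indices_size
      return batch_idx, outer_idx, indices_idx

    # With outer_size=3 and indices_size=4, work item 17 = 1*12 + 1*4 + 1.
    print(decompose(17, outer_size=3, indices_size=4))  # (1, 1, 1)
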
tensorflow/core/kernels/gather_functor_batched_gpu.cu.cc (new file)
@@ -0,0 +1,46 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/gather_functor_batched_gpu.cu.h"
#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

#define DEFINE_GPU_SPECS_INDEX(T, Index) \
  template struct functor::GatherFunctorBatched<GPUDevice, T, Index>

#define DEFINE_GPU_SPECS(T)         \
  DEFINE_GPU_SPECS_INDEX(T, int32); \
  DEFINE_GPU_SPECS_INDEX(T, int64);

TF_CALL_bool(DEFINE_GPU_SPECS);
TF_CALL_int32(DEFINE_GPU_SPECS);
TF_CALL_int64(DEFINE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
TF_CALL_complex64(DEFINE_GPU_SPECS);
TF_CALL_complex128(DEFINE_GPU_SPECS);

#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
tensorflow/core/kernels/gather_functor_batched_gpu.cu.h (new file)
@@ -0,0 +1,132 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

template <typename T, typename Index,
          bool is_axis_zero, bool is_batch_dims_zero>
__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
                               int64 outer_size,
                               int64 gather_dim_size, int64 indices_size,
                               int64 slice_size, int64 out_size) {
  // params is a tensor of shape
  // [batch_size, outer_size, gather_dim_size, slice_size].
  GPU_1D_KERNEL_LOOP(i, out_size) {
    Index batch_i = 0;    // The batch index into params to use for i.
    Index outer_i = 0;    // The outer index into params to use for i.
    Index indices_i = 0;  // The index into indices to use for i.
    Index slice_i = 0;  // Index into the current slice in params to use for i.

    const Index slices_count = i / slice_size;
    if (is_batch_dims_zero) {
      if (is_axis_zero) {
        indices_i = slices_count;
      } else {
        outer_i = slices_count / indices_size;
        indices_i = slices_count - outer_i * indices_size;
      }
    } else {
      const Index entries_count = slices_count / indices_size;
      if (is_axis_zero) {
        batch_i = entries_count;
      } else {
        batch_i = entries_count / outer_size;
        outer_i = entries_count - batch_i * outer_size;
      }
      indices_i = slices_count - entries_count * indices_size;
    }
    slice_i = i - slices_count * slice_size;

    // Index into the gather axis to use for i.
    Index gather_i = ldg(indices + batch_i * indices_size + indices_i);

    // Check that gather_i is in [0, gather_dim_size).
    if (!FastBoundsCheck(gather_i, gather_dim_size)) {
      // Set indices out of range to zero.
      // TODO(fpmc): Log an error for transfer back to host.
      out[i] = T(0);
    } else {
      // Read params[batch_i, outer_i, gather_i, slice_i] and write it to the
      // i'th position in out.
      Index params_i = (
          (batch_i * outer_size + outer_i) * gather_dim_size + gather_i
      ) * slice_size + slice_i;
      out[i] = ldg(params + params_i);
    }
  }
}

namespace functor {
template <typename T, typename Index>
struct GatherFunctorBatched<GPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out) {
    const GPUDevice& d = ctx->eigen_gpu_device();
    const int64 out_size = out.size();
    if (out_size == 0) {
      // We need a check here since the CPU version does useful error checking
      // work if there are nonempty indices but empty slices, so the kernel is
      // executed in that case. In the GPU case we don't know how to do error
      // checking, so we skip the loop entirely.
      return -1;
    }
    const bool is_batch_dims_zero = params.dimension(0) == 1;
    const bool is_axis_zero = params.dimension(1) == 1;
    const int64 outer_size = params.dimension(1);
    const int64 gather_dim_size = params.dimension(2);
    const int64 indices_size = indices.size() / params.dimension(0);
    const int64 slice_size = params.dimension(3);

    GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
    const auto function =
        is_axis_zero
            ? (is_batch_dims_zero ? GatherOpKernel<T, Index, true, true>
                                  : GatherOpKernel<T, Index, true, false>)
            : (is_batch_dims_zero ? GatherOpKernel<T, Index, false, true>
                                  : GatherOpKernel<T, Index, false, false>);
    TF_CHECK_OK(GpuLaunchKernel(
        function, config.block_count, config.thread_per_block, 0, d.stream(),
        params.data(), indices.data(), out.data(),
        outer_size, gather_dim_size, indices_size, slice_size, out_size));
    // TODO(fpmc): enable indices validation on GPU.
    // Right now checking for out-of-bounds indices in the kernel would
    // require copying code between GPU/CPU, and would thus be slow.
    return -1;
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
|
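For the GPU path, GatherOpKernel maps each flat output element i back to a flat params offset with pure index arithmetic. A sketch of the general case (is_axis_zero and is_batch_dims_zero both false), written in plain Python for readability; it mirrors the kernel above and is not TensorFlow code:

    def params_offset(i, indices, outer_size, gather_dim_size, indices_size,
                      slice_size):
      slices_count = i // slice_size
      slice_i = i - slices_count * slice_size
      entries_count = slices_count // indices_size
      indices_i = slices_count - entries_count * indices_size
      batch_i = entries_count // outer_size
      outer_i = entries_count - batch_i * outer_size
      # Look up the gather index, then rebuild the flat offset into params.
      gather_i = indices[batch_i * indices_size + indices_i]
      return ((batch_i * outer_size + outer_i) * gather_dim_size + gather_i) \
          * slice_size + slice_i
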
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/kernels/gather_functor_batched.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/util.h"
@@ -123,16 +124,22 @@ class GatherOp : public OpKernel {
     // The result shape is params.shape[:axis] + indices.shape[batch_dims:] +
     // params.shape[axis + 1:].
     TensorShape result_shape;
+    int64 batch_size = 1;
     int64 outer_size = 1;
     int64 inner_size = 1;
-    for (int i = 0; i < axis; i++) {
+
+    for (int i = 0; i < batch_dims_; ++i) {
+      result_shape.AddDim(params.dim_size(i));
+      batch_size *= params.dim_size(i);
+    }
+    for (int i = batch_dims_; i < axis; ++i) {
       result_shape.AddDim(params.dim_size(i));
       outer_size *= params.dim_size(i);
     }
     for (int i = batch_dims_; i < indices.dims(); ++i) {
       result_shape.AddDim(indices.dim_size(i));
     }
-    for (int i = axis + 1; i < params.dims(); i++) {
+    for (int i = axis + 1; i < params.dims(); ++i) {
       result_shape.AddDim(params.dim_size(i));
       inner_size *= params.dim_size(i);
     }
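
To make the bookkeeping above concrete (illustrative shapes, not taken from this change): with params of shape [2, 3, 5, 7], indices of shape [2, 4], batch_dims = 1 and axis = 2, the loops yield batch_size = 2, outer_size = 3, inner_size = 7 and result_shape = [2, 3, 4, 7]:

    from math import prod

    params_shape = [2, 3, 5, 7]   # assumed example: batch, outer, gather axis, inner
    indices_shape = [2, 4]        # assumed example: batch, indices per batch
    batch_dims, axis = 1, 2

    result_shape = (params_shape[:batch_dims] + params_shape[batch_dims:axis]
                    + indices_shape[batch_dims:] + params_shape[axis + 1:])
    batch_size = prod(params_shape[:batch_dims])
    outer_size = prod(params_shape[batch_dims:axis])
    inner_size = prod(params_shape[axis + 1:])
    print(result_shape, batch_size, outer_size, inner_size)  # [2, 3, 4, 7] 2 3 7
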
@@ -141,60 +148,29 @@ class GatherOp : public OpKernel {
     OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
     if (N == 0) return;
 
+    int64 bad_i = -1;
+    auto indices_flat = indices.flat<Index>();
     if (batch_dims_ > 0) {
-      // TODO(virimia): Switch to transpose / gather with axis=0 / transpose
-      // on GPU, to avoid launching a lot of small kernels.
-
-      // To avoid copying params (by transposing), run gather for each batch.
-      int64 batch_size = 1;
-      for (int i = 0; i < batch_dims_; ++i) {
-        batch_size *= params.dim_size(i);
-      }
-      outer_size /= batch_size;
-      auto batched_params =
-          params.shaped<T, 2>({batch_size, params.NumElements() / batch_size});
-      auto batched_indices =
-          indices.shaped<Index, 2>({batch_size, N / batch_size});
-      auto batched_out =
-          out->shaped<T, 2>({batch_size, out->NumElements() / batch_size});
-
-      // TODO(virimia): Investigate the best performance, when the number of
-      // batches is large, between parallel vs sequential runs.
-      for (int64 batch = 0; batch < batch_size; ++batch) {
-        auto params_flat = typename TTypes<T, 3>::ConstTensor(
-            &batched_params(batch, 0), static_cast<IndexType>(outer_size),
-            static_cast<IndexType>(gather_dim_size),
-            static_cast<IndexType>(inner_size));
-        auto indices_flat = typename TTypes<Index>::ConstFlat(
-            &batched_indices(batch, 0), batched_indices.dimension(1));
-        auto out_flat = typename TTypes<T, 3>::Tensor(
-            &batched_out(batch, 0), static_cast<IndexType>(outer_size),
-            static_cast<IndexType>(N), static_cast<IndexType>(inner_size));
-
-        functor::GatherFunctor<Device, T, Index> functor;
-        const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
-
-        OP_REQUIRES(
-            c, bad_i < 0,
-            errors::InvalidArgument(
-                "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-                indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
-      }
+      auto params_flat = params.shaped<T, 4>(
+          {batch_size, outer_size, gather_dim_size, inner_size});
+      auto out_flat = out->shaped<T, 4>(
+          {batch_size, outer_size, N / batch_size, inner_size});
+
+      functor::GatherFunctorBatched<Device, T, Index> functor;
+      bad_i = functor(c, params_flat, indices_flat, out_flat);
     } else {
       auto params_flat =
           params.shaped<T, 3>({outer_size, gather_dim_size, inner_size});
-      auto indices_flat = indices.flat<Index>();
       auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});
 
       functor::GatherFunctor<Device, T, Index> functor;
-      const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
-
-      OP_REQUIRES(
-          c, bad_i < 0,
-          errors::InvalidArgument(
-              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-              indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
+      bad_i = functor(c, params_flat, indices_flat, out_flat);
     }
+    OP_REQUIRES(
+        c, bad_i < 0,
+        errors::InvalidArgument(
+            "indices", SliceDebugString(indices.shape(), bad_i), " = ",
+            indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
   }
 
  private:
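The net effect of the rewrite above: instead of slicing params per batch and looping over the 3-D GatherFunctor, the op reshapes params and out to 4-D once and hands them to GatherFunctorBatched. A rough NumPy sketch of the equivalent indexing (sizes are illustrative only, not from this change):

    import numpy as np

    batch_size, outer_size, gather_dim, inner_size, n = 2, 3, 5, 4, 8
    params4 = np.arange(batch_size * outer_size * gather_dim * inner_size).reshape(
        batch_size, outer_size, gather_dim, inner_size)
    indices = np.random.randint(0, gather_dim, size=n)  # flat; n = batch_size * (n // batch_size)
    out4 = np.empty((batch_size, outer_size, n // batch_size, inner_size), params4.dtype)

    # out[b, o, j, :] = params[b, o, indices[b * (n // batch_size) + j], :]
    for b in range(batch_size):
      for o in range(outer_size):
        for j in range(n // batch_size):
          out4[b, o, j, :] = params4[b, o, indices[b * (n // batch_size) + j], :]
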
@@ -343,7 +343,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
     self.assertAllEqual(expected, result)
 
-    with compat.forward_compatibility_horizon(2019, 6, 11):
+    with compat.forward_compatibility_horizon(2019, 8, 11):
       result = array_ops.gather(
           params, indices, axis=axis, batch_dims=batch_dims)
 
@@ -443,7 +443,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(output_shape, result.shape.as_list())
     self.assertAllEqual(expected, result)
 
-    with compat.forward_compatibility_horizon(2019, 6, 11):
+    with compat.forward_compatibility_horizon(2019, 8, 11):
      result = array_ops.gather(
          params, indices, axis=axis, batch_dims=batch_dims)
 