Created a Gather functor that works with batch dimensions.

PiperOrigin-RevId: 261887312
A. Unique TensorFlower 2019-08-06 04:39:30 -07:00 committed by TensorFlower Gardener
parent ee202801b3
commit ecc0b95092
8 changed files with 461 additions and 50 deletions
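
For orientation: a minimal sketch of the batched gather semantics the new functor backs, expressed with the public tf.gather API; the shapes and values below are made up for illustration.

import tensorflow as tf

params = tf.reshape(tf.range(2 * 3 * 4), [2, 3, 4])   # [batch, gather_dim, slice]
indices = tf.constant([[2, 0], [1, 1]])                # [batch, num_indices]

# Batched gather: one lookup table per batch entry.
batched = tf.gather(params, indices, axis=1, batch_dims=1)

# Equivalent per-batch loop (roughly what the op did before this change).
per_batch = tf.stack(
    [tf.gather(params[b], indices[b], axis=0) for b in range(2)])

assert batched.shape.as_list() == [2, 2, 4]
tf.debugging.assert_equal(batched, per_batch)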


@@ -129,6 +129,7 @@ tensorflow/core/kernels/function_ops.cc
tensorflow/core/kernels/fused_batch_norm_op.cc
tensorflow/core/kernels/fused_eigen_output_kernels.cc
tensorflow/core/kernels/gather_functor.cc
tensorflow/core/kernels/gather_functor_batched.cc
tensorflow/core/kernels/gather_nd_op.cc
tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc


@@ -1887,7 +1887,10 @@ tf_kernel_library(
# Unlike the gather_functor library, this does not include the CUDA code and deps.
cc_library(
name = "gather_functor_hdr",
hdrs = ["gather_functor.h"],
hdrs = [
"gather_functor.h",
"gather_functor_batched.h",
],
)
tf_kernel_library(
@@ -6031,6 +6034,7 @@ filegroup(
"function_ops.cc",
"function_ops.h",
"gather_functor.h",
"gather_functor_batched.h",
"gather_nd_op.cc",
"gather_nd_op.h",
"gather_nd_op_cpu_impl.h",


@@ -0,0 +1,55 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// Forward declarations of the functor specializations for GPU.
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
template <> \
int64 GatherFunctorBatched<GPUDevice, T, Index>::operator()( \
OpKernelContext* ctx, typename TTypes<T, 4>::ConstTensor Tparams, \
typename TTypes<Index>::ConstFlat Tindices, \
typename TTypes<T, 4>::Tensor Tout); \
extern template struct GatherFunctorBatched<GPUDevice, T, Index>;
#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \
DECLARE_GPU_SPECS_INDEX(T, int64)
TF_CALL_int64(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
} // namespace functor
} // namespace tensorflow
#else
#include "tensorflow/core/kernels/gather_functor_batched.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@@ -0,0 +1,197 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// Helper method to copy using memcpy.
template <typename T, typename Index, typename SliceIndex,
SliceIndex static_slice_elems>
SliceIndex HandleCopiesBatched(OpKernelContext* ctx,
typename TTypes<T, 4>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
SliceIndex slice_elems,
typename TTypes<T, 4>::Tensor out) {
const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
const SliceIndex outer_size = static_cast<SliceIndex>(params.dimension(1));
const SliceIndex indices_size =
static_cast<SliceIndex>(indices.dimension(0)) / batch_size;
const Index limit = static_cast<Index>(params.dimension(2));
if (static_slice_elems >= 0) {
// Give compiler static knowledge of the number of elements/bytes
slice_elems = static_slice_elems;
}
// Compute slice_bytes here so that static knowledge is available
const size_t slice_bytes = slice_elems * sizeof(T);
auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
mutex mu;
// Stores the index of the first invalid value, for error reporting; it is
// shared across the worker threads (hence the mutex above).
SliceIndex result = -1;
auto work = [&](int64 start, int64 end) {
const int64 r_start = start % (outer_size * indices_size);
SliceIndex batch_idx = static_cast<SliceIndex>(
start / (outer_size * indices_size));
SliceIndex outer_idx = static_cast<SliceIndex>(r_start / indices_size);
SliceIndex indices_idx = static_cast<SliceIndex>(r_start % indices_size);
SliceIndex batch_offset = batch_idx * indices_size;
for (; start < end; ++start) {
SliceIndex i_next = indices_idx + 1;
SliceIndex o_next = outer_idx;
SliceIndex b_next = batch_idx;
SliceIndex b_offset_next = batch_offset;
if (i_next >= indices_size) {
i_next = 0;
if (++o_next >= outer_size) {
o_next = 0;
++b_next;
b_offset_next += indices_size;
}
}
if (start + 1 < end) {
port::prefetch<port::PREFETCH_HINT_T0>(
&params(b_next, o_next, indices(b_offset_next + i_next), 0));
port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, o_next, i_next, 0));
}
const Index index = internal::SubtleMustCopy(
indices(batch_offset + indices_idx));
if (!FastBoundsCheck(index, limit)) {
mutex_lock l(mu);
result = batch_offset + indices_idx;
return;
}
// Copy using memcpy if possible, otherwise an Eigen loop
// TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
// ahead-of-time compilation binary size).
if (is_simple_type<T>::value) {
// Avoid auto-promotion to Index from SliceIndex by casting.
memcpy(
&out(batch_idx, outer_idx, indices_idx, 0),
&params(batch_idx, outer_idx, static_cast<SliceIndex>(index), 0),
slice_bytes);
} else {
// For non-"simple" types (e.g. strings).
out.template chip<2>(indices_idx) = params.template chip<2>(index);
}
indices_idx = i_next;
outer_idx = o_next;
batch_idx = b_next;
batch_offset = b_offset_next;
}
};
Shard(worker_threads->num_threads, worker_threads->workers,
batch_size * outer_size * indices_size, slice_elems * sizeof(T), work);
return result;
}
template <typename T, typename Index>
struct GatherFunctorBatchedCPU {
int64 operator()(OpKernelContext* ctx,
typename TTypes<T, 4>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T, 4>::Tensor out) {
const int64 indices_size = indices.size(); // Includes the batch_size.
const int64 slice_size = out.dimension(3);
int64 bad_i;
const int64 batch_size = params.dimension(0);
const int64 outer_size = params.dimension(1);
bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
params.size() > std::numeric_limits<int32>::max() ||
indices_size > std::numeric_limits<int32>::max() ||
batch_size * outer_size * indices_size * slice_size >
std::numeric_limits<int32>::max());
#define CALL(elems) \
do { \
if (use_large) { \
bad_i = HandleCopiesBatched<T, Index, int64, elems>( \
ctx, params, indices, slice_size, out); \
} else { \
const int32 small_slice = static_cast<int32>(slice_size); \
bad_i = HandleCopiesBatched<T, Index, int32, elems>( \
ctx, params, indices, small_slice, out); \
} \
} while (0)
// TODO(rmlarsen): Investigate whether these specializations are still
// needed and, if yes, whether the slice sizes are appropriate.
if (slice_size == 10)
CALL(10);
else if (slice_size == 20)
CALL(20);
else
CALL(-1);
#undef CALL
return bad_i;
}
};
template <typename Device, typename T, typename Index>
struct GatherFunctorBatched {
int64 operator()(OpKernelContext* ctx,
typename TTypes<T, 4>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T, 4>::Tensor out);
};
template <typename T, typename Index>
struct GatherFunctorBatched<CPUDevice, T, Index> {
int64 operator()(OpKernelContext* ctx,
typename TTypes<T, 4>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T, 4>::Tensor out) {
return GatherFunctorBatchedCPU<T, Index>()(ctx, params, indices, out);
}
};
template <typename Index>
struct GatherFunctorBatched<GPUDevice, Variant, Index> {
int64 operator()(OpKernelContext* ctx,
typename TTypes<Variant, 4>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<Variant, 4>::Tensor out) {
return GatherFunctorBatchedCPU<Variant, Index>()(ctx, params, indices, out);
}
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
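
HandleCopiesBatched above views params as [batch, outer, gather_dim, slice] and walks (batch, outer, index) triples, copying one slice per step with memcpy. A minimal NumPy sketch of that index arithmetic, with assumed shapes and values:

import numpy as np

batch_size, outer_size, gather_dim, slice_elems = 2, 3, 5, 4
indices_size = 2                                   # indices per batch
params = np.arange(batch_size * outer_size * gather_dim * slice_elems).reshape(
    batch_size, outer_size, gather_dim, slice_elems)
indices = np.array([4, 0, 1, 3])                   # flat, length batch_size * indices_size

out = np.empty((batch_size, outer_size, indices_size, slice_elems), params.dtype)
for b in range(batch_size):
    for o in range(outer_size):
        for i in range(indices_size):
            idx = indices[b * indices_size + i]     # bounds-checked in the C++ code
            out[b, o, i, :] = params[b, o, idx, :]  # the slice that memcpy copies

# Same result via a per-batch take along the gather dimension.
expected = np.stack(
    [np.take(params[b], indices[b * indices_size:(b + 1) * indices_size], axis=1)
     for b in range(batch_size)])
assert np.array_equal(out, expected)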


@@ -0,0 +1,46 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/gather_functor_batched_gpu.cu.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
template struct functor::GatherFunctorBatched<GPUDevice, T, Index>
#define DEFINE_GPU_SPECS(T) \
DEFINE_GPU_SPECS_INDEX(T, int32); \
DEFINE_GPU_SPECS_INDEX(T, int64);
TF_CALL_bool(DEFINE_GPU_SPECS);
TF_CALL_int32(DEFINE_GPU_SPECS);
TF_CALL_int64(DEFINE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
TF_CALL_complex64(DEFINE_GPU_SPECS);
TF_CALL_complex128(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@@ -0,0 +1,132 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
template <typename T, typename Index,
bool is_axis_zero, bool is_batch_dims_zero>
__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
int64 outer_size,
int64 gather_dim_size, int64 indices_size,
int64 slice_size, int64 out_size) {
// params is a tensor of shape
// [batch_size, outer_size, gather_dim_size, slice_size].
GPU_1D_KERNEL_LOOP(i, out_size) {
Index batch_i = 0; // The batch index into params to use for i.
Index outer_i = 0; // The outer index into params to use for i.
Index indices_i = 0; // The index into indices to use for i.
Index slice_i = 0; // Index into the current slice in params to use for i.
const Index slices_count = i / slice_size;
if (is_batch_dims_zero) {
if (is_axis_zero) {
indices_i = slices_count;
} else {
outer_i = slices_count / indices_size;
indices_i = slices_count - outer_i * indices_size;
}
} else {
const Index entries_count = slices_count / indices_size;
if (is_axis_zero) {
batch_i = entries_count;
} else {
batch_i = entries_count / outer_size;
outer_i = entries_count - batch_i * outer_size;
}
indices_i = slices_count - entries_count * indices_size;
}
slice_i = i - slices_count * slice_size;
// Index into the gather axis to use for i.
Index gather_i = ldg(indices + batch_i * indices_size + indices_i);
// Check gather_i is in [0, gather_dim_size).
if (!FastBoundsCheck(gather_i, gather_dim_size)) {
// For out-of-range indices, write zero to the output.
// TODO(fpmc): Log an error for transfer back to host.
out[i] = T(0);
} else {
// Read params[batch_i, outer_i, gather_i, slice_i] and write it to the
// i'th position in out.
Index params_i = (
(batch_i * outer_size + outer_i) * gather_dim_size + gather_i
) * slice_size + slice_i;
out[i] = ldg(params + params_i);
}
}
}
namespace functor {
template <typename T, typename Index>
struct GatherFunctorBatched<GPUDevice, T, Index> {
int64 operator()(OpKernelContext* ctx,
typename TTypes<T, 4>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T, 4>::Tensor out) {
const GPUDevice& d = ctx->eigen_gpu_device();
const int64 out_size = out.size();
if (out_size == 0) {
// We need a check here since the CPU version does useful error checking
// work if there are nonempty indices but empty slices, so the kernel is
// executed in that case. In the GPU case we don't know how to do error
// checking, so we skip the loop entirely.
return -1;
}
const bool is_batch_dims_zero = params.dimension(0) == 1;
const bool is_axis_zero = params.dimension(1) == 1;
const int64 outer_size = params.dimension(1);
const int64 gather_dim_size = params.dimension(2);
const int64 indices_size = indices.size() / params.dimension(0);
const int64 slice_size = params.dimension(3);
GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
const auto function = is_axis_zero ?
(is_batch_dims_zero ?
GatherOpKernel<T, Index, true, true>:
GatherOpKernel<T, Index, true, false>) :
(is_batch_dims_zero ?
GatherOpKernel<T, Index, false, true>:
GatherOpKernel<T, Index, false, false>);
TF_CHECK_OK(GpuLaunchKernel(
function, config.block_count, config.thread_per_block, 0, d.stream(),
params.data(), indices.data(), out.data(),
outer_size, gather_dim_size, indices_size, slice_size, out_size));
// TODO(fpmc): enable indices validation on GPU.
// Right now, checking for out-of-bounds indices in the kernel would require
// copying code between GPU/CPU, and would thus be slow.
return -1;
}
};
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
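
For the general GatherOpKernel case (is_axis_zero and is_batch_dims_zero both false), the flat-index decomposition can be sanity-checked on the host. A NumPy sketch with assumed shapes and values:

import numpy as np

batch, outer, gather_dim, indices_size, slice_size = 2, 3, 5, 2, 4
params = np.arange(batch * outer * gather_dim * slice_size).reshape(
    batch, outer, gather_dim, slice_size)
indices = np.array([4, 0, 1, 3])            # flat, length batch * indices_size
flat_params = params.ravel()
out = np.empty(batch * outer * indices_size * slice_size, params.dtype)

for i in range(out.size):                   # one GPU loop iteration per output element
    slices_count = i // slice_size
    entries_count = slices_count // indices_size
    batch_i = entries_count // outer
    outer_i = entries_count - batch_i * outer
    indices_i = slices_count - entries_count * indices_size
    slice_i = i - slices_count * slice_size
    gather_i = indices[batch_i * indices_size + indices_i]
    params_i = ((batch_i * outer + outer_i) * gather_dim + gather_i) * slice_size + slice_i
    out[i] = flat_params[params_i]

# Cross-check against direct 4D indexing of params.
expected = params[np.arange(batch)[:, None, None],
                  np.arange(outer)[None, :, None],
                  indices.reshape(batch, 1, indices_size), :]
assert np.array_equal(out.reshape(batch, outer, indices_size, slice_size), expected)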


@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
@@ -123,16 +124,22 @@ class GatherOp : public OpKernel {
// The result shape is params.shape[:axis] + indices.shape[batch_dims:] +
// params.shape[axis + 1:].
TensorShape result_shape;
int64 batch_size = 1;
int64 outer_size = 1;
int64 inner_size = 1;
for (int i = 0; i < axis; i++) {
for (int i = 0; i < batch_dims_; ++i) {
result_shape.AddDim(params.dim_size(i));
batch_size *= params.dim_size(i);
}
for (int i = batch_dims_; i < axis; ++i) {
result_shape.AddDim(params.dim_size(i));
outer_size *= params.dim_size(i);
}
for (int i = batch_dims_; i < indices.dims(); ++i) {
result_shape.AddDim(indices.dim_size(i));
}
for (int i = axis + 1; i < params.dims(); i++) {
for (int i = axis + 1; i < params.dims(); ++i) {
result_shape.AddDim(params.dim_size(i));
inner_size *= params.dim_size(i);
}
@@ -141,60 +148,29 @@ class GatherOp : public OpKernel {
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
if (N == 0) return;
int64 bad_i = -1;
auto indices_flat = indices.flat<Index>();
if (batch_dims_ > 0) {
// TODO(virimia): Switch to transpose / gather with axis=0 / transpose
// on GPU, to avoid launching a lot of small kernels.
auto params_flat = params.shaped<T, 4>(
{batch_size, outer_size, gather_dim_size, inner_size});
auto out_flat = out->shaped<T, 4>(
{batch_size, outer_size, N / batch_size, inner_size});
// To avoid copying params (by transposing), run gather for each batch.
int64 batch_size = 1;
for (int i = 0; i < batch_dims_; ++i) {
batch_size *= params.dim_size(i);
}
outer_size /= batch_size;
auto batched_params =
params.shaped<T, 2>({batch_size, params.NumElements() / batch_size});
auto batched_indices =
indices.shaped<Index, 2>({batch_size, N / batch_size});
auto batched_out =
out->shaped<T, 2>({batch_size, out->NumElements() / batch_size});
// TODO(virimia): Investigate whether parallel or sequential runs perform
// best when the number of batches is large.
for (int64 batch = 0; batch < batch_size; ++batch) {
auto params_flat = typename TTypes<T, 3>::ConstTensor(
&batched_params(batch, 0), static_cast<IndexType>(outer_size),
static_cast<IndexType>(gather_dim_size),
static_cast<IndexType>(inner_size));
auto indices_flat = typename TTypes<Index>::ConstFlat(
&batched_indices(batch, 0), batched_indices.dimension(1));
auto out_flat = typename TTypes<T, 3>::Tensor(
&batched_out(batch, 0), static_cast<IndexType>(outer_size),
static_cast<IndexType>(N), static_cast<IndexType>(inner_size));
functor::GatherFunctor<Device, T, Index> functor;
const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
}
functor::GatherFunctorBatched<Device, T, Index> functor;
bad_i = functor(c, params_flat, indices_flat, out_flat);
} else {
auto params_flat =
params.shaped<T, 3>({outer_size, gather_dim_size, inner_size});
auto indices_flat = indices.flat<Index>();
auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});
functor::GatherFunctor<Device, T, Index> functor;
const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
bad_i = functor(c, params_flat, indices_flat, out_flat);
}
OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
}
private:
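
The op now hands GatherFunctorBatched a single 4D view of params, [batch_size, outer_size, gather_dim_size, inner_size], instead of looping over batches. A minimal sketch of the equivalence at the public-API level, with assumed shapes and values:

import tensorflow as tf

# params already has the [batch_size, outer_size, gather_dim_size, inner_size]
# layout that the op builds with params.shaped<T, 4>(); values are made up.
params = tf.reshape(tf.range(2 * 3 * 5 * 4), [2, 3, 5, 4])
indices = tf.constant([[4, 0], [1, 3]])     # [batch_size, N / batch_size]
axis, batch_dims = 2, 1

# What GatherFunctorBatched computes, expressed as a per-batch gather.
batched = tf.stack(
    [tf.gather(params[b], indices[b], axis=axis - batch_dims) for b in range(2)])

tf.debugging.assert_equal(
    batched, tf.gather(params, indices, axis=axis, batch_dims=batch_dims))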


@@ -343,7 +343,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
self.assertAllEqual(expected, result)
with compat.forward_compatibility_horizon(2019, 6, 11):
with compat.forward_compatibility_horizon(2019, 8, 11):
result = array_ops.gather(
params, indices, axis=axis, batch_dims=batch_dims)
@@ -443,7 +443,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(output_shape, result.shape.as_list())
self.assertAllEqual(expected, result)
with compat.forward_compatibility_horizon(2019, 6, 11):
with compat.forward_compatibility_horizon(2019, 8, 11):
result = array_ops.gather(
params, indices, axis=axis, batch_dims=batch_dims)