diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index e284353f2b0..d233fe63bad 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -129,6 +129,7 @@ tensorflow/core/kernels/function_ops.cc
 tensorflow/core/kernels/fused_batch_norm_op.cc
 tensorflow/core/kernels/fused_eigen_output_kernels.cc
 tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_functor_batched.cc
 tensorflow/core/kernels/gather_nd_op.cc
 tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
 tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 58e5664d28b..35caa3ac1a1 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1887,7 +1887,10 @@ tf_kernel_library(
 # Unlike gather_functor library, this does not include the CUDA code and deps.
 cc_library(
     name = "gather_functor_hdr",
-    hdrs = ["gather_functor.h"],
+    hdrs = [
+        "gather_functor.h",
+        "gather_functor_batched.h",
+    ],
 )
 
 tf_kernel_library(
@@ -6031,6 +6034,7 @@ filegroup(
         "function_ops.cc",
         "function_ops.h",
         "gather_functor.h",
+        "gather_functor_batched.h",
         "gather_nd_op.cc",
         "gather_nd_op.h",
         "gather_nd_op_cpu_impl.h",
diff --git a/tensorflow/core/kernels/gather_functor_batched.cc b/tensorflow/core/kernels/gather_functor_batched.cc
new file mode 100644
index 00000000000..0960b3a2472
--- /dev/null
+++ b/tensorflow/core/kernels/gather_functor_batched.cc
@@ -0,0 +1,55 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include "tensorflow/core/kernels/gather_functor_batched.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+// Forward declarations of the functor specializations for GPU.
+#define DECLARE_GPU_SPECS_INDEX(T, Index)                                 \
+  template <>                                                             \
+  int64 GatherFunctorBatched<GPUDevice, T, Index>::operator()(            \
+      OpKernelContext* ctx, typename TTypes<T, 4>::ConstTensor Tparams,   \
+      typename TTypes<Index>::ConstFlat Tindices,                         \
+      typename TTypes<T, 4>::Tensor Tout);                                \
+  extern template struct GatherFunctorBatched<GPUDevice, T, Index>;
+
+#define DECLARE_GPU_SPECS(T)         \
+  DECLARE_GPU_SPECS_INDEX(T, int32); \
+  DECLARE_GPU_SPECS_INDEX(T, int64)
+
+TF_CALL_int64(DECLARE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_complex64(DECLARE_GPU_SPECS);
+TF_CALL_complex128(DECLARE_GPU_SPECS);
+
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_SPECS_INDEX
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#else
+
+#include "tensorflow/core/kernels/gather_functor_batched.h"
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/gather_functor_batched.h b/tensorflow/core/kernels/gather_functor_batched.h
new file mode 100644
index 00000000000..fa9ac72a3fd
--- /dev/null
+++ b/tensorflow/core/kernels/gather_functor_batched.h
@@ -0,0 +1,197 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "tensorflow/core/framework/bounds_check.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/platform/prefetch.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+// Helper method to copy using memcpy.
+template <typename T, typename Index, typename SliceIndex,
+          SliceIndex static_slice_elems>
+SliceIndex HandleCopiesBatched(OpKernelContext* ctx,
+                               typename TTypes<T, 4>::ConstTensor params,
+                               typename TTypes<Index>::ConstFlat indices,
+                               SliceIndex slice_elems,
+                               typename TTypes<T, 4>::Tensor out) {
+  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
+  const SliceIndex outer_size = static_cast<SliceIndex>(params.dimension(1));
+  const SliceIndex indices_size =
+      static_cast<SliceIndex>(indices.dimension(0)) / batch_size;
+
+  const Index limit = static_cast<Index>(params.dimension(2));
+  if (static_slice_elems >= 0) {
+    // Give compiler static knowledge of the number of elements/bytes
+    slice_elems = static_slice_elems;
+  }
+  // Compute slice_bytes here so that static knowledge is available
+  const size_t slice_bytes = slice_elems * sizeof(T);
+  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+  mutex mu;
+  // Stores the first invalid index for error reporting; it is a variable
+  // shared across shards.
+  SliceIndex result = -1;
+  auto work = [&](int64 start, int64 end) {
+    const int64 r_start = start % (outer_size * indices_size);
+    SliceIndex batch_idx = static_cast<SliceIndex>(
+        start / (outer_size * indices_size));
+    SliceIndex outer_idx = static_cast<SliceIndex>(r_start / indices_size);
+    SliceIndex indices_idx = static_cast<SliceIndex>(r_start % indices_size);
+
+    SliceIndex batch_offset = batch_idx * indices_size;
+    for (; start < end; ++start) {
+      SliceIndex i_next = indices_idx + 1;
+      SliceIndex o_next = outer_idx;
+      SliceIndex b_next = batch_idx;
+      SliceIndex b_offset_next = batch_offset;
+
+      if (i_next >= indices_size) {
+        i_next = 0;
+        if (++o_next >= outer_size) {
+          o_next = 0;
+          ++b_next;
+          b_offset_next += indices_size;
+        }
+      }
+      if (start + 1 < end) {
+        port::prefetch<port::PREFETCH_HINT_T0>(
+            &params(b_next, o_next, indices(b_offset_next + i_next), 0));
+        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, o_next, i_next, 0));
+      }
+      const Index index = internal::SubtleMustCopy(
+          indices(batch_offset + indices_idx));
+      if (!FastBoundsCheck(index, limit)) {
+        mutex_lock l(mu);
+        result = batch_offset + indices_idx;
+        return;
+      }
+
+      // Copy using memcpy if possible, otherwise an Eigen loop
+      // TODO(cwhipkey): avoid linking to framework to get Allocator (to
+      // improve ahead-of-time compilation binary size).
+      if (is_simple_type<T>::value) {
+        // Avoid auto-promotion to Index from SliceIndex by casting.
+        memcpy(
+            &out(batch_idx, outer_idx, indices_idx, 0),
+            &params(batch_idx, outer_idx, static_cast<SliceIndex>(index), 0),
+            slice_bytes);
+      } else {
+        // For non-"simple" types (e.g. strings).
+        out.template chip<2>(indices_idx) = params.template chip<2>(index);
+      }
+
+      indices_idx = i_next;
+      outer_idx = o_next;
+      batch_idx = b_next;
+      batch_offset = b_offset_next;
+    }
+  };
+
+  Shard(worker_threads->num_threads, worker_threads->workers,
+        batch_size * outer_size * indices_size, slice_elems * sizeof(T), work);
+  return result;
+}
+
+template <typename T, typename Index>
+struct GatherFunctorBatchedCPU {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<T, 4>::Tensor out) {
+    const int64 indices_size = indices.size();  // Includes the batch_size.
+    const int64 slice_size = out.dimension(3);
+    int64 bad_i;
+
+    const int64 batch_size = params.dimension(0);
+    const int64 outer_size = params.dimension(1);
+
+    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
+                      params.size() > std::numeric_limits<int32>::max() ||
+                      indices_size > std::numeric_limits<int32>::max() ||
+                      batch_size * outer_size * indices_size * slice_size >
+                          std::numeric_limits<int32>::max());
+#define CALL(elems)                                             \
+  do {                                                          \
+    if (use_large) {                                            \
+      bad_i = HandleCopiesBatched<T, Index, int64, elems>(      \
+          ctx, params, indices, slice_size, out);               \
+    } else {                                                    \
+      const int32 small_slice = static_cast<int32>(slice_size); \
+      bad_i = HandleCopiesBatched<T, Index, int32, elems>(      \
+          ctx, params, indices, small_slice, out);              \
+    }                                                           \
+  } while (0)
+
+    // TODO(rmlarsen): Investigate whether these specializations are still
+    // needed and, if yes, whether the slice sizes are appropriate.
+    if (slice_size == 10)
+      CALL(10);
+    else if (slice_size == 20)
+      CALL(20);
+    else
+      CALL(-1);
+#undef CALL
+
+    return bad_i;
+  }
+};
+
+template <typename Device, typename T, typename Index>
+struct GatherFunctorBatched {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<T, 4>::Tensor out);
+};
+
+template <typename T, typename Index>
+struct GatherFunctorBatched<CPUDevice, T, Index> {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<T, 4>::Tensor out) {
+    return GatherFunctorBatchedCPU<T, Index>()(ctx, params, indices, out);
+  }
+};
+
+template <typename Index>
+struct GatherFunctorBatched<GPUDevice, Variant, Index> {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<Variant, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<Variant, 4>::Tensor out) {
+    return GatherFunctorBatchedCPU<Variant, Index>()(ctx, params, indices, out);
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
diff --git a/tensorflow/core/kernels/gather_functor_batched_gpu.cu.cc b/tensorflow/core/kernels/gather_functor_batched_gpu.cu.cc
new file mode 100644
index 00000000000..f118d8dc72b
--- /dev/null
+++ b/tensorflow/core/kernels/gather_functor_batched_gpu.cu.cc
@@ -0,0 +1,46 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/gather_functor_batched_gpu.cu.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_SPECS_INDEX(T, Index) \
+  template struct functor::GatherFunctorBatched<GPUDevice, T, Index>
+
+#define DEFINE_GPU_SPECS(T)         \
+  DEFINE_GPU_SPECS_INDEX(T, int32); \
+  DEFINE_GPU_SPECS_INDEX(T, int64);
+
+TF_CALL_bool(DEFINE_GPU_SPECS);
+TF_CALL_int32(DEFINE_GPU_SPECS);
+TF_CALL_int64(DEFINE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+TF_CALL_complex64(DEFINE_GPU_SPECS);
+TF_CALL_complex128(DEFINE_GPU_SPECS);
+
+#undef DEFINE_GPU_SPECS
+#undef DEFINE_GPU_SPECS_INDEX
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h b/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h
new file mode 100644
index 00000000000..24c23f1f900
--- /dev/null
+++ b/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h
@@ -0,0 +1,132 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/gather_functor_batched.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename T, typename Index, bool is_axis_zero,
+          bool is_batch_dims_zero>
+__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
+                               int64 outer_size,
+                               int64 gather_dim_size, int64 indices_size,
+                               int64 slice_size, int64 out_size) {
+  // params is a tensor of shape
+  // [batch_size, outer_size, gather_dim_size, slice_size].
+  GPU_1D_KERNEL_LOOP(i, out_size) {
+    Index batch_i = 0;    // The batch index into params to use for i.
+    Index outer_i = 0;    // The outer index into params to use for i.
+    Index indices_i = 0;  // The index into indices to use for i.
+    Index slice_i = 0;    // Index into the current slice in params to use for i.
+
+    const Index slices_count = i / slice_size;
+    if (is_batch_dims_zero) {
+      if (is_axis_zero) {
+        indices_i = slices_count;
+      } else {
+        outer_i = slices_count / indices_size;
+        indices_i = slices_count - outer_i * indices_size;
+      }
+    } else {
+      const Index entries_count = slices_count / indices_size;
+      if (is_axis_zero) {
+        batch_i = entries_count;
+      } else {
+        batch_i = entries_count / outer_size;
+        outer_i = entries_count - batch_i * outer_size;
+      }
+      indices_i = slices_count - entries_count * indices_size;
+    }
+    slice_i = i - slices_count * slice_size;
+
+    // Index into the gather axis to use for i.
+    Index gather_i = ldg(indices + batch_i * indices_size + indices_i);
+
+    // Check gather_i is in [0, gather_dim_size).
+    if (!FastBoundsCheck(gather_i, gather_dim_size)) {
+      // Set indices out of range to zero
+      // TODO(fpmc): Log an error for transfer back to host.
+      out[i] = T(0);
+    } else {
+      // Read params[batch_i, outer_i, gather_i, slice_i] and write it to the
+      // i'th position in out.
+      Index params_i = (
+          (batch_i * outer_size + outer_i) * gather_dim_size + gather_i
+      ) * slice_size + slice_i;
+      out[i] = ldg(params + params_i);
+    }
+  }
+}
+
+namespace functor {
+template <typename T, typename Index>
+struct GatherFunctorBatched<GPUDevice, T, Index> {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 4>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<T, 4>::Tensor out) {
+    const GPUDevice& d = ctx->eigen_gpu_device();
+    const int64 out_size = out.size();
+    if (out_size == 0) {
+      // We need a check here since the CPU version does useful error checking
+      // work if there are nonempty indices but empty slices, so the kernel is
+      // executed in that case. In the GPU case we don't know how to do error
+      // checking, so we skip the loop entirely.
+      return -1;
+    }
+    const bool is_batch_dims_zero = params.dimension(0) == 1;
+    const bool is_axis_zero = params.dimension(1) == 1;
+    const int64 outer_size = params.dimension(1);
+    const int64 gather_dim_size = params.dimension(2);
+    const int64 indices_size = indices.size() / params.dimension(0);
+    const int64 slice_size = params.dimension(3);
+
+    GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
+    const auto function = is_axis_zero ?
+        (is_batch_dims_zero ?
+            GatherOpKernel<T, Index, true, true> :
+            GatherOpKernel<T, Index, true, false>) :
+        (is_batch_dims_zero ?
+            GatherOpKernel<T, Index, false, true> :
+            GatherOpKernel<T, Index, false, false>);
+    TF_CHECK_OK(GpuLaunchKernel(
+        function, config.block_count, config.thread_per_block, 0, d.stream(),
+        params.data(), indices.data(), out.data(),
+        outer_size, gather_dim_size, indices_size, slice_size, out_size));
+    // TODO(fpmc): enable indices validation on GPU.
+    // Right now, checking for out-of-bounds indices in the kernel would
+    // require copying code between GPU/CPU, and would thus be slow.
+    return -1;
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index 68c258da6ad..38e0bab676d 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/kernels/gather_functor_batched.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/util.h"
@@ -123,16 +124,22 @@ class GatherOp : public OpKernel {
     // The result shape is params.shape[:axis] + indices.shape[batch_dims:] +
     // params.shape[axis + 1:].
     TensorShape result_shape;
+    int64 batch_size = 1;
     int64 outer_size = 1;
     int64 inner_size = 1;
-    for (int i = 0; i < axis; i++) {
+
+    for (int i = 0; i < batch_dims_; ++i) {
+      result_shape.AddDim(params.dim_size(i));
+      batch_size *= params.dim_size(i);
+    }
+    for (int i = batch_dims_; i < axis; ++i) {
       result_shape.AddDim(params.dim_size(i));
       outer_size *= params.dim_size(i);
     }
     for (int i = batch_dims_; i < indices.dims(); ++i) {
       result_shape.AddDim(indices.dim_size(i));
     }
-    for (int i = axis + 1; i < params.dims(); i++) {
+    for (int i = axis + 1; i < params.dims(); ++i) {
       result_shape.AddDim(params.dim_size(i));
       inner_size *= params.dim_size(i);
     }
@@ -141,60 +148,29 @@ class GatherOp : public OpKernel {
     OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
     if (N == 0) return;
 
+    int64 bad_i = -1;
+    auto indices_flat = indices.flat<Index>();
     if (batch_dims_ > 0) {
-      // TODO(virimia): Switch to transpose / gather with axis=0 / transpose
-      // on GPU, to avoid launching a lot of small kernels.
+      auto params_flat = params.shaped<T, 4>(
+          {batch_size, outer_size, gather_dim_size, inner_size});
+      auto out_flat = out->shaped<T, 4>(
+          {batch_size, outer_size, N / batch_size, inner_size});
 
-      // To avoid copying params (by transposing), run gather for each batch.
-      int64 batch_size = 1;
-      for (int i = 0; i < batch_dims_; ++i) {
-        batch_size *= params.dim_size(i);
-      }
-      outer_size /= batch_size;
-      auto batched_params =
-          params.shaped<T, 2>({batch_size, params.NumElements() / batch_size});
-      auto batched_indices =
-          indices.shaped<Index, 2>({batch_size, N / batch_size});
-      auto batched_out =
-          out->shaped<T, 2>({batch_size, out->NumElements() / batch_size});
-
-      // TODO(virimia): Investigate the best performance, when the number of
-      // batches is large, between parallel vs sequential runs.
-      for (int64 batch = 0; batch < batch_size; ++batch) {
-        auto params_flat = typename TTypes<T, 3>::ConstTensor(
-            &batched_params(batch, 0), static_cast<Index>(outer_size),
-            static_cast<Index>(gather_dim_size),
-            static_cast<Index>(inner_size));
-        auto indices_flat = typename TTypes<Index>::ConstFlat(
-            &batched_indices(batch, 0), batched_indices.dimension(1));
-        auto out_flat = typename TTypes<T, 3>::Tensor(
-            &batched_out(batch, 0), static_cast<Index>(outer_size),
-            static_cast<Index>(N), static_cast<Index>(inner_size));
-
-        functor::GatherFunctor<Device, T, Index> functor;
-        const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
-
-        OP_REQUIRES(
-            c, bad_i < 0,
-            errors::InvalidArgument(
-                "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-                indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
-      }
+      functor::GatherFunctorBatched<Device, T, Index> functor;
+      bad_i = functor(c, params_flat, indices_flat, out_flat);
     } else {
       auto params_flat =
           params.shaped<T, 3>({outer_size, gather_dim_size, inner_size});
-      auto indices_flat = indices.flat<Index>();
       auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});
 
       functor::GatherFunctor<Device, T, Index> functor;
-      const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
-
-      OP_REQUIRES(
-          c, bad_i < 0,
-          errors::InvalidArgument(
-              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-              indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
+      bad_i = functor(c, params_flat, indices_flat, out_flat);
     }
+    OP_REQUIRES(
+        c, bad_i < 0,
+        errors::InvalidArgument(
+            "indices", SliceDebugString(indices.shape(), bad_i), " = ",
+            indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
   }
 
  private:
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index f23b7d33664..031389cd349 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -343,7 +343,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
       result = array_ops.gather(
          params, indices, axis=axis, batch_dims=batch_dims)
      self.assertAllEqual(expected, result)
 
-    with compat.forward_compatibility_horizon(2019, 6, 11):
+    with compat.forward_compatibility_horizon(2019, 8, 11):
       result = array_ops.gather(
           params, indices, axis=axis, batch_dims=batch_dims)
 
@@ -443,7 +443,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(output_shape, result.shape.as_list())
     self.assertAllEqual(expected, result)
 
-    with compat.forward_compatibility_horizon(2019, 6, 11):
+    with compat.forward_compatibility_horizon(2019, 8, 11):
       result = array_ops.gather(
           params, indices, axis=axis, batch_dims=batch_dims)
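
For reference, the semantics implemented by the new batched path (params viewed as [batch_size, outer_size, gather_dim_size, inner_size] and gathered along dimension 2 with a per-batch block of indices) can be sketched in NumPy as follows. This is an illustrative reference only; the helper name and the brute-force loops below are not part of the patch.

import numpy as np

def batched_gather_reference(params, indices, batch_size, outer_size,
                             gather_dim_size, inner_size):
  """Reference for the 4-D batched gather computed by GatherFunctorBatched.

  params is viewed as [batch_size, outer_size, gather_dim_size, inner_size],
  indices as [batch_size, indices_per_batch]. The result has shape
  [batch_size, outer_size, indices_per_batch, inner_size] with
  out[b, o, i, s] = params[b, o, indices[b, i], s].
  """
  p = params.reshape(batch_size, outer_size, gather_dim_size, inner_size)
  ind = indices.reshape(batch_size, -1)
  n = ind.shape[1]
  out = np.empty((batch_size, outer_size, n, inner_size), dtype=p.dtype)
  for b in range(batch_size):
    for o in range(outer_size):
      for i in range(n):
        out[b, o, i, :] = p[b, o, ind[b, i], :]
  return out

# Example: batch_dims=1, axis=1 gather on a [2, 3, 4] params tensor.
params = np.arange(2 * 3 * 4).reshape(2, 3, 4)
indices = np.array([[2, 0], [1, 1]])
ref = batched_gather_reference(params, indices, batch_size=2, outer_size=1,
                               gather_dim_size=3, inner_size=4)
# ref has shape [2, 1, 2, 4]; squeezing the outer dimension gives the same
# values as tf.gather(params, indices, axis=1, batch_dims=1).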
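
The GPU kernel above recovers (batch_i, outer_i, indices_i, slice_i) from each flat output index using integer division and subtraction. The sketch below mirrors that arithmetic for the general case (batch_dims != 0 and axis != 0) in plain Python; the function name is illustrative and not part of the patch.

def decompose_output_index(i, outer_size, indices_size, slice_size):
  """Mirrors the index arithmetic in GatherOpKernel for the general case.

  The flat output index i addresses an output laid out as
  [batch, outer, indices, slice].
  """
  slices_count = i // slice_size           # which output slice i falls in
  slice_i = i - slices_count * slice_size  # position inside that slice
  entries_count = slices_count // indices_size
  indices_i = slices_count - entries_count * indices_size
  batch_i = entries_count // outer_size
  outer_i = entries_count - batch_i * outer_size
  return batch_i, outer_i, indices_i, slice_i

# out[i] is then read from
# params[((batch_i * outer_size + outer_i) * gather_dim_size + gather_i)
#        * slice_size + slice_i],
# where gather_i = indices[batch_i * indices_size + indices_i].
#
# Example: with outer_size=2, indices_size=3, slice_size=4, flat index 29
# decomposes into batch 1, outer 0, indices entry 1, slice position 1.
assert decompose_output_index(29, 2, 3, 4) == (1, 0, 1, 1)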