Created a Gather functor that works with batch dimensions.
PiperOrigin-RevId: 261887312
This commit is contained in:
parent ee202801b3
commit ecc0b95092
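For reference, the new GatherFunctorBatched below operates on a 4-D view of params, [batch_size, outer_size, gather_dim_size, slice_size], and for every (batch, outer) pair copies the slices selected by that batch's indices. A minimal reference loop sketching those semantics (illustration only; gather_batched_reference and its flat-buffer layout are not part of this change):

#include <cstdint>
#include <vector>

// Reference semantics of the batched gather on a flattened 4-D view:
//   out[b, o, i, :] = params[b, o, indices[b * indices_per_batch + i], :]
// Out-of-range indices are reported via the returned position, mirroring the
// functor's "return bad_i, or -1 if none" convention.
template <typename T, typename Index>
int64_t gather_batched_reference(const std::vector<T>& params,
                                 const std::vector<Index>& indices,
                                 std::vector<T>& out,
                                 int64_t batch_size, int64_t outer_size,
                                 int64_t gather_dim_size, int64_t slice_size) {
  const int64_t indices_per_batch = indices.size() / batch_size;
  for (int64_t b = 0; b < batch_size; ++b) {
    for (int64_t o = 0; o < outer_size; ++o) {
      for (int64_t i = 0; i < indices_per_batch; ++i) {
        const Index idx = indices[b * indices_per_batch + i];
        if (idx < 0 || idx >= gather_dim_size) return b * indices_per_batch + i;
        const int64_t src =
            ((b * outer_size + o) * gather_dim_size + idx) * slice_size;
        const int64_t dst =
            ((b * outer_size + o) * indices_per_batch + i) * slice_size;
        for (int64_t s = 0; s < slice_size; ++s) out[dst + s] = params[src + s];
      }
    }
  }
  return -1;  // No invalid index found.
}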
@@ -129,6 +129,7 @@ tensorflow/core/kernels/function_ops.cc
tensorflow/core/kernels/fused_batch_norm_op.cc
tensorflow/core/kernels/fused_eigen_output_kernels.cc
tensorflow/core/kernels/gather_functor.cc
tensorflow/core/kernels/gather_functor_batched.cc
tensorflow/core/kernels/gather_nd_op.cc
tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
@@ -1887,7 +1887,10 @@ tf_kernel_library(
# Unlike gather_functor library, this does not include the CUDA code and deps.
cc_library(
    name = "gather_functor_hdr",
    hdrs = ["gather_functor.h"],
    hdrs = [
        "gather_functor.h",
        "gather_functor_batched.h",
    ],
)

tf_kernel_library(
@@ -6031,6 +6034,7 @@ filegroup(
        "function_ops.cc",
        "function_ops.h",
        "gather_functor.h",
        "gather_functor_batched.h",
        "gather_nd_op.cc",
        "gather_nd_op.h",
        "gather_nd_op_cpu_impl.h",
tensorflow/core/kernels/gather_functor_batched.cc (Normal file, 55 lines)
@@ -0,0 +1,55 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Forward declarations of the functor specializations for GPU.
#define DECLARE_GPU_SPECS_INDEX(T, Index)                               \
  template <>                                                           \
  int64 GatherFunctorBatched<GPUDevice, T, Index>::operator()(          \
      OpKernelContext* ctx, typename TTypes<T, 4>::ConstTensor Tparams, \
      typename TTypes<Index>::ConstFlat Tindices,                       \
      typename TTypes<T, 4>::Tensor Tout);                              \
  extern template struct GatherFunctorBatched<GPUDevice, T, Index>;

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_int64(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX

}  // namespace functor
}  // namespace tensorflow

#else

#include "tensorflow/core/kernels/gather_functor_batched.h"

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
tensorflow/core/kernels/gather_functor_batched.h (Normal file, 197 lines)
@@ -0,0 +1,197 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Helper method to copy using memcpy.
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
SliceIndex HandleCopiesBatched(OpKernelContext* ctx,
                               typename TTypes<T, 4>::ConstTensor params,
                               typename TTypes<Index>::ConstFlat indices,
                               SliceIndex slice_elems,
                               typename TTypes<T, 4>::Tensor out) {
  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
  const SliceIndex outer_size = static_cast<SliceIndex>(params.dimension(1));
  const SliceIndex indices_size =
      static_cast<SliceIndex>(indices.dimension(0)) / batch_size;

  const Index limit = static_cast<Index>(params.dimension(2));
  if (static_slice_elems >= 0) {
    // Give compiler static knowledge of the number of elements/bytes
    slice_elems = static_slice_elems;
  }
  // Compute slice_bytes here so that static knowledge is available
  const size_t slice_bytes = slice_elems * sizeof(T);
  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
  mutex mu;
  // Store the value of the invalid index for printing error information; it's
  // a shared variable.
  SliceIndex result = -1;
  auto work = [&](int64 start, int64 end) {
    const int64 r_start = start % (outer_size * indices_size);
    SliceIndex batch_idx = static_cast<SliceIndex>(
        start / (outer_size * indices_size));
    SliceIndex outer_idx = static_cast<SliceIndex>(r_start / indices_size);
    SliceIndex indices_idx = static_cast<SliceIndex>(r_start % indices_size);

    SliceIndex batch_offset = batch_idx * indices_size;
    for (; start < end; ++start) {
      SliceIndex i_next = indices_idx + 1;
      SliceIndex o_next = outer_idx;
      SliceIndex b_next = batch_idx;
      SliceIndex b_offset_next = batch_offset;

      if (i_next >= indices_size) {
        i_next = 0;
        if (++o_next >= outer_size) {
          o_next = 0;
          ++b_next;
          b_offset_next += indices_size;
        }
      }
      if (start + 1 < end) {
        port::prefetch<port::PREFETCH_HINT_T0>(
            &params(b_next, o_next, indices(b_offset_next + i_next), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, o_next, i_next, 0));
      }
      const Index index = internal::SubtleMustCopy(
          indices(batch_offset + indices_idx));
      if (!FastBoundsCheck(index, limit)) {
        mutex_lock l(mu);
        result = batch_offset + indices_idx;
        return;
      }

      // Copy using memcpy if possible, otherwise an Eigen loop
      // TODO(cwhipkey): avoid linking to framework to get Allocator (to
      // improve ahead-of-time compilation binary size).
      if (is_simple_type<T>::value) {
        // Avoid auto-promotion to Index from SliceIndex by casting.
        memcpy(
            &out(batch_idx, outer_idx, indices_idx, 0),
            &params(batch_idx, outer_idx, static_cast<SliceIndex>(index), 0),
            slice_bytes);
      } else {
        // For non-"simple" types (e.g. strings).
        out.template chip<2>(indices_idx) = params.template chip<2>(index);
      }

      indices_idx = i_next;
      outer_idx = o_next;
      batch_idx = b_next;
      batch_offset = b_offset_next;
    }
  };

  Shard(worker_threads->num_threads, worker_threads->workers,
        batch_size * outer_size * indices_size, slice_elems * sizeof(T), work);
  return result;
}

template <typename T, typename Index>
struct GatherFunctorBatchedCPU {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out) {
    const int64 indices_size = indices.size();  // Includes the batch_size.
    const int64 slice_size = out.dimension(3);
    int64 bad_i;

    const int64 batch_size = params.dimension(0);
    const int64 outer_size = params.dimension(1);

    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                      params.size() > std::numeric_limits<int32>::max() ||
                      indices_size > std::numeric_limits<int32>::max() ||
                      batch_size * outer_size * indices_size * slice_size >
                          std::numeric_limits<int32>::max());
#define CALL(elems)                                             \
  do {                                                          \
    if (use_large) {                                            \
      bad_i = HandleCopiesBatched<T, Index, int64, elems>(      \
          ctx, params, indices, slice_size, out);               \
    } else {                                                    \
      const int32 small_slice = static_cast<int32>(slice_size); \
      bad_i = HandleCopiesBatched<T, Index, int32, elems>(      \
          ctx, params, indices, small_slice, out);              \
    }                                                           \
  } while (0)

    // TODO(rmlarsen): Investigate whether these specializations are still
    // needed and, if yes, whether the slice sizes are appropriate.
    if (slice_size == 10)
      CALL(10);
    else if (slice_size == 20)
      CALL(20);
    else
      CALL(-1);
#undef CALL

    return bad_i;
  }
};

template <typename Device, typename T, typename Index>
struct GatherFunctorBatched {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out);
};

template <typename T, typename Index>
struct GatherFunctorBatched<CPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out) {
    return GatherFunctorBatchedCPU<T, Index>()(ctx, params, indices, out);
  }
};

template <typename Index>
struct GatherFunctorBatched<GPUDevice, Variant, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<Variant, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<Variant, 4>::Tensor out) {
    return GatherFunctorBatchedCPU<Variant, Index>()(ctx, params, indices, out);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
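The sharded CPU path in HandleCopiesBatched above walks a flat work range and updates (batch_idx, outer_idx, indices_idx) incrementally instead of redoing divisions per element. A standalone sketch of that decode, checked against the direct division-based form (names here are illustrative, not from the TensorFlow sources):

#include <cassert>
#include <cstdint>

// Decode a flat work position into (batch, outer, index) coordinates, the
// same layout HandleCopiesBatched shards over:
//   position = (batch * outer_size + outer) * indices_size + index.
struct Coords { int64_t batch, outer, index; };

Coords DecodeFlat(int64_t pos, int64_t outer_size, int64_t indices_size) {
  const int64_t r = pos % (outer_size * indices_size);
  return {pos / (outer_size * indices_size), r / indices_size,
          r % indices_size};
}

int main() {
  const int64_t batch_size = 3, outer_size = 4, indices_size = 5;
  // Incremental walk, mirroring the i_next/o_next/b_next updates in the
  // worker lambda, checked against the division-based decode.
  int64_t b = 0, o = 0, i = 0;
  for (int64_t pos = 0; pos < batch_size * outer_size * indices_size; ++pos) {
    const Coords c = DecodeFlat(pos, outer_size, indices_size);
    assert(c.batch == b && c.outer == o && c.index == i);
    if (++i >= indices_size) {
      i = 0;
      if (++o >= outer_size) { o = 0; ++b; }
    }
  }
  return 0;
}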
tensorflow/core/kernels/gather_functor_batched_gpu.cu.cc (Normal file, 46 lines)
@@ -0,0 +1,46 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/gather_functor_batched_gpu.cu.h"
#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

#define DEFINE_GPU_SPECS_INDEX(T, Index) \
  template struct functor::GatherFunctorBatched<GPUDevice, T, Index>

#define DEFINE_GPU_SPECS(T)         \
  DEFINE_GPU_SPECS_INDEX(T, int32); \
  DEFINE_GPU_SPECS_INDEX(T, int64);

TF_CALL_bool(DEFINE_GPU_SPECS);
TF_CALL_int32(DEFINE_GPU_SPECS);
TF_CALL_int64(DEFINE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
TF_CALL_complex64(DEFINE_GPU_SPECS);
TF_CALL_complex128(DEFINE_GPU_SPECS);

#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
tensorflow/core/kernels/gather_functor_batched_gpu.cu.h (Normal file, 132 lines)
@@ -0,0 +1,132 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

template <typename T, typename Index,
          bool is_axis_zero, bool is_batch_dims_zero>
__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
                               int64 outer_size,
                               int64 gather_dim_size, int64 indices_size,
                               int64 slice_size, int64 out_size) {
  // params is a tensor of shape
  // [batch_size, outer_size, gather_dim_size, slice_size].
  GPU_1D_KERNEL_LOOP(i, out_size) {
    Index batch_i = 0;    // The batch index into params to use for i.
    Index outer_i = 0;    // The outer index into params to use for i.
    Index indices_i = 0;  // The index into indices to use for i.
    Index slice_i = 0;    // Index into the current slice in params to use for i.

    const Index slices_count = i / slice_size;
    if (is_batch_dims_zero) {
      if (is_axis_zero) {
        indices_i = slices_count;
      } else {
        outer_i = slices_count / indices_size;
        indices_i = slices_count - outer_i * indices_size;
      }
    } else {
      const Index entries_count = slices_count / indices_size;
      if (is_axis_zero) {
        batch_i = entries_count;
      } else {
        batch_i = entries_count / outer_size;
        outer_i = entries_count - batch_i * outer_size;
      }
      indices_i = slices_count - entries_count * indices_size;
    }
    slice_i = i - slices_count * slice_size;

    // Index into the gather axis to use for i.
    Index gather_i = ldg(indices + batch_i * indices_size + indices_i);

    // Check gather_i is in [0, gather_dim_size).
    if (!FastBoundsCheck(gather_i, gather_dim_size)) {
      // Set indices out of range to zero
      // TODO(fpmc): Log an error for transfer back to host.
      out[i] = T(0);
    } else {
      // Read params[batch_i, outer_i, gather_i, slice_i] and write it to the
      // i'th position in out.
      Index params_i = (
          (batch_i * outer_size + outer_i) * gather_dim_size + gather_i
      ) * slice_size + slice_i;
      out[i] = ldg(params + params_i);
    }
  }
}

namespace functor {
template <typename T, typename Index>
struct GatherFunctorBatched<GPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 4>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 4>::Tensor out) {
    const GPUDevice& d = ctx->eigen_gpu_device();
    const int64 out_size = out.size();
    if (out_size == 0) {
      // We need a check here since the CPU version does useful error checking
      // work if there are nonempty indices but empty slices, so the kernel is
      // executed in that case. In the GPU case we don't know how to do error
      // checking, so we skip the loop entirely.
      return -1;
    }
    const bool is_batch_dims_zero = params.dimension(0) == 1;
    const bool is_axis_zero = params.dimension(1) == 1;
    const int64 outer_size = params.dimension(1);
    const int64 gather_dim_size = params.dimension(2);
    const int64 indices_size = indices.size() / params.dimension(0);
    const int64 slice_size = params.dimension(3);

    GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
    const auto function = is_axis_zero ?
        (is_batch_dims_zero ?
            GatherOpKernel<T, Index, true, true> :
            GatherOpKernel<T, Index, true, false>) :
        (is_batch_dims_zero ?
            GatherOpKernel<T, Index, false, true> :
            GatherOpKernel<T, Index, false, false>);
    TF_CHECK_OK(GpuLaunchKernel(
        function, config.block_count, config.thread_per_block, 0, d.stream(),
        params.data(), indices.data(), out.data(),
        outer_size, gather_dim_size, indices_size, slice_size, out_size));
    // TODO(fpmc): enable indices validation on GPU.
    // Right now, checking for indices out of bounds in the kernel would
    // require copying code between GPU/CPU, and would thus be slow.
    return -1;
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_
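The GPU kernel above maps each flat output element i directly to an offset in params. A host-side sketch of that arithmetic for the general case (is_axis_zero and is_batch_dims_zero both false), useful for sanity-checking the index math; the helper name is made up for illustration:

#include <cstdint>
#include <vector>

// Reproduce the GatherOpKernel index math for the general case: given a flat
// output index i, compute the offset of the params element it reads.
int64_t ParamsOffsetForOutput(int64_t i, const std::vector<int64_t>& indices,
                              int64_t outer_size, int64_t gather_dim_size,
                              int64_t indices_size, int64_t slice_size) {
  const int64_t slices_count = i / slice_size;  // Which output slice i is in.
  const int64_t entries_count = slices_count / indices_size;
  const int64_t batch_i = entries_count / outer_size;
  const int64_t outer_i = entries_count - batch_i * outer_size;
  const int64_t indices_i = slices_count - entries_count * indices_size;
  const int64_t slice_i = i - slices_count * slice_size;
  const int64_t gather_i = indices[batch_i * indices_size + indices_i];
  // An out-of-range gather_i is zero-filled by the kernel; ignored here.
  return ((batch_i * outer_size + outer_i) * gather_dim_size + gather_i)
             * slice_size + slice_i;
}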
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/gather_functor_batched.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
@@ -123,16 +124,22 @@ class GatherOp : public OpKernel {
    // The result shape is params.shape[:axis] + indices.shape[batch_dims:] +
    // params.shape[axis + 1:].
    TensorShape result_shape;
    int64 batch_size = 1;
    int64 outer_size = 1;
    int64 inner_size = 1;
    for (int i = 0; i < axis; i++) {

    for (int i = 0; i < batch_dims_; ++i) {
      result_shape.AddDim(params.dim_size(i));
      batch_size *= params.dim_size(i);
    }
    for (int i = batch_dims_; i < axis; ++i) {
      result_shape.AddDim(params.dim_size(i));
      outer_size *= params.dim_size(i);
    }
    for (int i = batch_dims_; i < indices.dims(); ++i) {
      result_shape.AddDim(indices.dim_size(i));
    }
    for (int i = axis + 1; i < params.dims(); i++) {
    for (int i = axis + 1; i < params.dims(); ++i) {
      result_shape.AddDim(params.dim_size(i));
      inner_size *= params.dim_size(i);
    }
@@ -141,60 +148,29 @@ class GatherOp : public OpKernel {
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    if (N == 0) return;

    int64 bad_i = -1;
    auto indices_flat = indices.flat<Index>();
    if (batch_dims_ > 0) {
      // TODO(virimia): Switch to transpose / gather with axis=0 / transpose
      // on GPU, to avoid launching a lot of small kernels.
      auto params_flat = params.shaped<T, 4>(
          {batch_size, outer_size, gather_dim_size, inner_size});
      auto out_flat = out->shaped<T, 4>(
          {batch_size, outer_size, N / batch_size, inner_size});

      // To avoid copying params (by transposing), run gather for each batch.
      int64 batch_size = 1;
      for (int i = 0; i < batch_dims_; ++i) {
        batch_size *= params.dim_size(i);
      }
      outer_size /= batch_size;
      auto batched_params =
          params.shaped<T, 2>({batch_size, params.NumElements() / batch_size});
      auto batched_indices =
          indices.shaped<Index, 2>({batch_size, N / batch_size});
      auto batched_out =
          out->shaped<T, 2>({batch_size, out->NumElements() / batch_size});

      // TODO(virimia): Investigate the best performance, when the number of
      // batches is large, between parallel vs sequential runs.
      for (int64 batch = 0; batch < batch_size; ++batch) {
        auto params_flat = typename TTypes<T, 3>::ConstTensor(
            &batched_params(batch, 0), static_cast<IndexType>(outer_size),
            static_cast<IndexType>(gather_dim_size),
            static_cast<IndexType>(inner_size));
        auto indices_flat = typename TTypes<Index>::ConstFlat(
            &batched_indices(batch, 0), batched_indices.dimension(1));
        auto out_flat = typename TTypes<T, 3>::Tensor(
            &batched_out(batch, 0), static_cast<IndexType>(outer_size),
            static_cast<IndexType>(N), static_cast<IndexType>(inner_size));

        functor::GatherFunctor<Device, T, Index> functor;
        const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

        OP_REQUIRES(
            c, bad_i < 0,
            errors::InvalidArgument(
                "indices", SliceDebugString(indices.shape(), bad_i), " = ",
                indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
      }
      functor::GatherFunctorBatched<Device, T, Index> functor;
      bad_i = functor(c, params_flat, indices_flat, out_flat);
    } else {
      auto params_flat =
          params.shaped<T, 3>({outer_size, gather_dim_size, inner_size});
      auto indices_flat = indices.flat<Index>();
      auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});

      functor::GatherFunctor<Device, T, Index> functor;
      const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
      bad_i = functor(c, params_flat, indices_flat, out_flat);
    }
    OP_REQUIRES(
        c, bad_i < 0,
        errors::InvalidArgument(
            "indices", SliceDebugString(indices.shape(), bad_i), " = ",
            indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
  }

 private:
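The updated GatherOp::Compute above reduces a general (batch_dims, axis) gather to the 4-D layout the batched functor expects: leading dimensions fold into batch_size, dimensions between batch_dims and axis into outer_size, and trailing dimensions into inner_size. A sketch of that shape folding (a standalone helper for illustration; no such function exists in gather_op.cc):

#include <cstdint>
#include <vector>

struct Folded4D {
  int64_t batch_size, outer_size, gather_dim_size, inner_size;
};

// Fold params.shape into [batch_size, outer_size, gather_dim_size, inner_size]
// the way GatherOp::Compute does before calling GatherFunctorBatched:
//   dims [0, batch_dims)     -> batch_size
//   dims [batch_dims, axis)  -> outer_size
//   dim  axis                -> gather_dim_size
//   dims (axis, rank)        -> inner_size
Folded4D FoldParamsShape(const std::vector<int64_t>& params_shape,
                         int batch_dims, int axis) {
  Folded4D f{1, 1, params_shape[axis], 1};
  for (int i = 0; i < batch_dims; ++i) f.batch_size *= params_shape[i];
  for (int i = batch_dims; i < axis; ++i) f.outer_size *= params_shape[i];
  for (int i = axis + 1; i < static_cast<int>(params_shape.size()); ++i)
    f.inner_size *= params_shape[i];
  return f;
}
// The output tensor is reshaped analogously to
// [batch_size, outer_size, N / batch_size, inner_size], where N is the total
// number of indices.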
@@ -343,7 +343,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
    result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
    self.assertAllEqual(expected, result)

    with compat.forward_compatibility_horizon(2019, 6, 11):
    with compat.forward_compatibility_horizon(2019, 8, 11):
      result = array_ops.gather(
          params, indices, axis=axis, batch_dims=batch_dims)

@@ -443,7 +443,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)

    with compat.forward_compatibility_horizon(2019, 6, 11):
    with compat.forward_compatibility_horizon(2019, 8, 11):
      result = array_ops.gather(
          params, indices, axis=axis, batch_dims=batch_dims)