The C++ kernel of the Gather op supports batch dimensions.

The XLA kernel of the Gather op supports batch dimensions.

The C++ implementation of gather now takes a batch_dims attribute and works in the same way as ResourceVariable's gather kernel.
By default, batch_dims is set to 0, for compatibility with existing graphs that do not set the attribute.

PiperOrigin-RevId: 248708783
This commit is contained in:
A. Unique TensorFlower 2019-05-17 06:29:21 -07:00 committed by TensorFlower Gardener
parent f9a4227ae5
commit 41fb815c89
9 changed files with 193 additions and 23 deletions

View File

@ -195,6 +195,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:training_ops",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@ -20,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -148,15 +152,22 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
class GatherOp : public XlaOpKernel {
public:
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
// Set batch_dims_ to 0 if the attribute does not exist.
if (context->HasAttr("batch_dims")) {
OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
} else {
batch_dims_ = 0;
}
}
void Compile(XlaOpKernelContext* context) override {
xla::XlaBuilder* builder = context->builder();
auto input = context->Input(0);
auto input_shape = context->InputShape(0);
auto indices = context->Input(1);
auto indices_shape = context->InputShape(1);
int64 axis = 0;
absl::optional<int64> axis;
if (context->num_inputs() == 3) {
const TensorShape axis_shape = context->InputShape(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
@ -165,31 +176,73 @@ class GatherOp : public XlaOpKernel {
OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
errors::InvalidArgument("axis must be int32 or int64"));
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
int64 axis_input;
OP_REQUIRES_OK(context,
context->ConstantInputAsIntScalar(2, &axis_input));
const auto params_dims = input_shape.dims();
OP_REQUIRES(
context, -params_dims <= axis && axis < params_dims,
errors::InvalidArgument("Expected axis in the range [", -params_dims,
", ", params_dims, "), but got ", axis));
if (axis < 0) {
axis += params_dims;
OP_REQUIRES(context,
-params_dims <= axis_input && axis_input < params_dims,
errors::InvalidArgument("Expected axis in the range [",
-params_dims, ", ", params_dims,
"), but got ", axis_input));
if (axis_input < 0) {
axis_input += params_dims;
}
axis = axis_input;
}
if (batch_dims_ != 0) {
  // Validate the attribute exactly as the user supplied it, *before* any
  // normalization. Checking after wrapping a negative value would let an
  // out-of-range input such as batch_dims = -(rank + 1) slip through as -1,
  // and would report the rewritten value in the error message.
  const int indices_rank = indices_shape.dims();
  OP_REQUIRES(context,
              batch_dims_ >= -indices_rank && batch_dims_ < indices_rank,
              errors::InvalidArgument("Expected batch_dims in the range [",
                                      -indices_rank, ", ", indices_rank,
                                      "), but got ", batch_dims_));
  // Negative values count from the end of the indices shape.
  // NOTE(review): this overwrites the member, so a later compilation with
  // indices of a different rank would see the already-normalized value.
  // Consider normalizing into a local copy instead.
  if (batch_dims_ < 0) {
    batch_dims_ = indices_rank + batch_dims_;
  }
  // When no axis input was provided, gather along the first non-batch
  // dimension.
  axis = axis.value_or(batch_dims_);
  OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
              errors::InvalidArgument("batch_dims (", batch_dims_,
                                      ") must be less than rank(input) (",
                                      input_shape.dims(), ")."));
  OP_REQUIRES(context, *axis >= batch_dims_,
              errors::InvalidArgument("batch_dims (", batch_dims_,
                                      ") must be less than or equal to ",
                                      "axis (", *axis, ")."));
}
axis = axis.value_or(0);
DataType index_type = input_type(1);
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("indices must be int32 or int64"));
xla::XlaOp gather;
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, axis,
/*indices_are_nd=*/false, input_type(0), index_type,
builder, &gather));
if (batch_dims_ > 0) {
gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
} else {
// XlaGather() manages degenerate cases, like empty-indices, which are
// error conditions and caught above if batch_dims is not 0.
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, *axis,
/*indices_are_nd=*/false, input_type(0),
index_type, context->builder(), &gather));
}
context->SetOutput(0, gather);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
// The number of batch dimensions, as passed in the batch_dims attribute.
// It must be less than rank(indices).
int32 batch_dims_ = 0;
};
REGISTER_XLA_OP(Name("Gather"), GatherOp);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/gather_functor.h"
@ -39,7 +40,14 @@ class GatherOp : public OpKernel {
// we have the framework do some sort of integer promotion
// automatically, or should that be something that users have to
// do explicitly with a conversion operator in the graph?
explicit GatherOp(OpKernelConstruction* c) : OpKernel(c) {}
explicit GatherOp(OpKernelConstruction* c) : OpKernel(c) {
// Set batch_dims_ to 0 if the attribute does not exist.
if (c->HasAttr("batch_dims")) {
OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
} else {
batch_dims_ = 0;
}
}
void Compute(OpKernelContext* c) override {
const Tensor& params = c->input(0);
@ -51,7 +59,9 @@ class GatherOp : public OpKernel {
// GatherV2 added an axis argument. For backwards compatibility with Gather,
// fall back to axis 0 if the op does not have an axis input.
int64 axis = 0;
bool axis_is_set = false; // Indicates whether the axis argument was set.
if (c->num_inputs() == 3) {
axis_is_set = true;
const Tensor& axis_tensor = c->input(2);
OP_REQUIRES(c, TensorShapeUtils::IsScalar(axis_tensor.shape()),
errors::InvalidArgument("axis must be scalar"));
@ -70,12 +80,37 @@ class GatherOp : public OpKernel {
c, axis >= -params.dims() && axis < params.dims(),
errors::InvalidArgument("Expected axis in the range [", -params.dims(),
", ", params.dims(), "), but got ", axis));
if (axis < 0) {
axis = params.dims() + axis;
}
if (batch_dims_ != 0) {
  // Range-check the attribute as supplied, *before* normalizing it. If the
  // check ran after wrapping, batch_dims = -(rank + 1) would become -1, pass
  // the ">= -rank" test, and later be used as a (negative) dimension index
  // into the indices shape.
  const int indices_rank = indices.dims();
  OP_REQUIRES(
      c, batch_dims_ >= -indices_rank && batch_dims_ < indices_rank,
      errors::InvalidArgument("Expected batch_dims in the range [",
                              -indices_rank, ", ", indices_rank,
                              "), but got ", batch_dims_));
  // Negative values count from the end of the indices shape.
  // NOTE(review): this writes to a member from Compute(). A kernel instance
  // may be invoked repeatedly (and concurrently) with indices of different
  // rank; normalizing into a local copy would avoid stale values and races.
  if (batch_dims_ < 0) {
    batch_dims_ = indices_rank + batch_dims_;
  }
  // Without an explicit axis, gather along the first non-batch dimension.
  if (!axis_is_set) axis = batch_dims_;
  OP_REQUIRES(c, batch_dims_ < params.dims(),
              errors::InvalidArgument("batch_dims (", batch_dims_,
                                      ") must be less than rank(params) (",
                                      params.dims(), ")."));
  OP_REQUIRES(c, axis >= batch_dims_,
              errors::InvalidArgument("batch_dims (", batch_dims_,
                                      ") must be less than or equal to ",
                                      "axis (", axis, ")."));
}
// Check that we have enough index space
const int64 gather_dim_size = params.dim_size(axis);
int64 gather_dim_size = params.dim_size(axis);
const int64 N = indices.NumElements();
OP_REQUIRES(
c, gather_dim_size <= std::numeric_limits<Index>::max(),
@ -84,7 +119,7 @@ class GatherOp : public OpKernel {
" indexing: ", gather_dim_size, " > ",
std::numeric_limits<Index>::max()));
// The result shape is params.shape[0:axis] + indices.shape +
// The result shape is params.shape[:axis] + indices.shape[batch_dims:] +
// params.shape[axis + 1:].
TensorShape result_shape;
int64 outer_size = 1;
@ -93,7 +128,9 @@ class GatherOp : public OpKernel {
result_shape.AddDim(params.dim_size(i));
outer_size *= params.dim_size(i);
}
result_shape.AppendShape(indices.shape());
for (int i = batch_dims_; i < indices.dims(); ++i) {
result_shape.AddDim(indices.dim_size(i));
}
for (int i = axis + 1; i < params.dims(); i++) {
result_shape.AddDim(params.dim_size(i));
inner_size *= params.dim_size(i);
@ -101,14 +138,53 @@ class GatherOp : public OpKernel {
Tensor* out = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
if (N > 0 && outer_size > 0 && inner_size > 0) {
if (N == 0) return;
if (batch_dims_ > 0) {
// TODO(virimia): Switch to transpose / gather with axis=0 / transpose
// on GPU, to avoid launching a lot of small kernels.
// To avoid copying params (by transposing), run gather for each batch.
int64 batch_size = 1;
for (int i = 0; i < batch_dims_; ++i) {
batch_size *= params.dim_size(i);
}
outer_size /= batch_size;
auto batched_params =
params.shaped<T, 2>({batch_size, params.NumElements() / batch_size});
auto batched_indices =
indices.shaped<Index, 2>({batch_size, N / batch_size});
auto batched_out =
out->shaped<T, 2>({batch_size, out->NumElements() / batch_size});
// TODO(virimia): Investigate the best performance, when the number of
// batches is large, between parallel vs sequential runs.
for (int64 batch = 0; batch < batch_size; ++batch) {
auto params_flat = typename TTypes<T, 3>::ConstTensor(
&batched_params(batch, 0),
{outer_size, gather_dim_size, inner_size});
auto indices_flat = typename TTypes<Index>::ConstFlat(
&batched_indices(batch, 0), {batched_indices.dimension(1)});
auto out_flat = typename TTypes<T, 3>::Tensor(
&batched_out(batch, 0), {outer_size, N, inner_size});
functor::GatherFunctor<Device, T, Index> functor;
const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
}
} else {
auto params_flat =
params.shaped<T, 3>({outer_size, gather_dim_size, inner_size});
auto indices_flat = indices.flat<Index>();
auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});
functor::GatherFunctor<Device, T, Index> functor;
int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
OP_REQUIRES(
c, bad_i < 0,
@ -117,6 +193,11 @@ class GatherOp : public OpKernel {
indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
}
}
private:
// The number of batch dimensions, as passed in the batch_dims attribute.
// It must be less than rank(indices).
int32 batch_dims_ = 0;
};
#define REGISTER_GATHER_FULL(dev, type, index_type) \

View File

@ -1113,6 +1113,7 @@ REGISTER_OP("GatherV2")
.Input("params: Tparams")
.Input("indices: Tindices")
.Input("axis: Taxis")
.Attr("batch_dims: int = 0")
.Output("output: Tparams")
.Attr("Tparams: type")
.Attr("Tindices: {int32,int64}")
@ -1151,13 +1152,24 @@ REGISTER_OP("GatherV2")
TF_RETURN_IF_ERROR(c->WithRankAtLeast(
params_shape, axis < 0 ? -axis : axis + 1, &unused));
// Note, batch_dims can be negative.
int32 batch_dims;
TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(
params_shape, batch_dims < 0 ? -batch_dims : batch_dims + 1,
&unused));
ShapeHandle params_outer_subshape;
TF_RETURN_IF_ERROR(
c->Subshape(params_shape, 0, axis, &params_outer_subshape));
ShapeHandle indices_inner_subshape;
TF_RETURN_IF_ERROR(
c->Subshape(indices_shape, batch_dims, &indices_inner_subshape));
ShapeHandle out;
TF_RETURN_IF_ERROR(
c->Concatenate(params_outer_subshape, indices_shape, &out));
c->Concatenate(params_outer_subshape, indices_inner_subshape, &out));
// Slice from axis + 1 to the end of params_shape to collect the inner
// dimensions of the result. Special case -1 here since -1 + 1 wraps, and

View File

@ -293,6 +293,7 @@ TEST(ArrayOpsTest, Gather_ShapeFn) {
TEST(ArrayOpsTest, GatherV2_ShapeFn) {
ShapeInferenceTestOp op("GatherV2");
AddNodeAttr("batch_dims", 0, &op.node_def);
// Tests when axis is unknown.
INFER_OK(op, "?;?;?", "?");

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -35,6 +36,8 @@ from tensorflow.python.platform import test
_TEST_TYPES = (dtypes.int64, dtypes.float32,
dtypes.complex64, dtypes.complex128)
# TODO(virimia): Add a benchmark for gather_v2, with batch_dims and axis set.
class GatherTest(test.TestCase, parameterized.TestCase):
@ -340,6 +343,12 @@ class GatherTest(test.TestCase, parameterized.TestCase):
result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
self.assertAllEqual(expected, result)
with compat.forward_compatibility_horizon(2019, 6, 11):
result = array_ops.gather(
params, indices, axis=axis, batch_dims=batch_dims)
self.assertAllEqual(expected, result)
@parameterized.parameters([
dict(
params_shape=[2, 3, 4, 5, 6, 7],
@ -434,6 +443,13 @@ class GatherTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(output_shape, result.shape.as_list())
self.assertAllEqual(expected, result)
with compat.forward_compatibility_horizon(2019, 6, 11):
result = array_ops.gather(
params, indices, axis=axis, batch_dims=batch_dims)
self.assertAllEqual(output_shape, result.shape.as_list())
self.assertAllEqual(expected, result)
def _batchNumpyGather(self, params, indices, axis, batch_dims):
"""Performs a batch gather by making recursive calls to np.take().

View File

@ -3443,6 +3443,12 @@ def gather(params,
A `Tensor`. Has the same type as `params`.
"""
del validate_indices
if compat.forward_compatible(2019, 6, 10):
if axis is None:
axis = batch_dims
return gen_array_ops.gather_v2(
params, indices, axis, batch_dims=batch_dims, name=name)
if batch_dims != 0:
with ops.name_scope(name, "Gather", [params, indices, axis]):
return _batch_gather(params, indices, batch_dims, axis)

View File

@ -1430,7 +1430,7 @@ tf_module {
}
member_method {
name: "GatherV2"
argspec: "args=[\'params\', \'indices\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'params\', \'indices\', \'axis\', \'batch_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
}
member_method {
name: "GenerateVocabRemapping"

View File

@ -1430,7 +1430,7 @@ tf_module {
}
member_method {
name: "GatherV2"
argspec: "args=[\'params\', \'indices\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'params\', \'indices\', \'axis\', \'batch_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
}
member_method {
name: "GenerateVocabRemapping"