The C++ kernel of gather op supports batch dimensions.
The XLA kernel of Gather op supports batch dimensions. The C++ implementation of gather now takes a batch_dims argument and works in the same way as ResourceVariable's gather kernel. By default, batch_dims is set to 0, for compatibility with existing code.

PiperOrigin-RevId: 248708783
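For orientation, here is a small usage sketch of the semantics this change enables. It is illustrative only (the values are not taken from the diff) and assumes a TensorFlow build that includes this commit, run eagerly:

    import tensorflow as tf

    # params shares one leading batch dimension with indices.
    params = tf.constant([[10., 11., 12.],
                          [20., 21., 22.]])   # shape [2, 3]
    indices = tf.constant([[2, 0],
                           [1, 1]])           # shape [2, 2]

    # With batch_dims=1, batch b of the result gathers from batch b of params:
    #   result[b, j] = params[b, indices[b, j]]
    result = tf.gather(params, indices, axis=1, batch_dims=1)
    # result == [[12., 10.],
    #            [21., 21.]]                  # shape [2, 2]

The result shape follows the rule stated in the kernel comment below: params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:].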
commit 41fb815c89
parent f9a4227ae5

Changed paths: tensorflow/compiler/tf2xla/kernels, tensorflow/core, tensorflow/python, tensorflow/tools/api/golden
@@ -195,6 +195,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:training_ops",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -20,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -148,15 +152,22 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,

class GatherOp : public XlaOpKernel {
public:
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
// Set batch_dims_ to 0 if the attribute does not exist.
if (context->HasAttr("batch_dims")) {
OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
} else {
batch_dims_ = 0;
}
}

void Compile(XlaOpKernelContext* context) override {
xla::XlaBuilder* builder = context->builder();
auto input = context->Input(0);
auto input_shape = context->InputShape(0);
auto indices = context->Input(1);
auto indices_shape = context->InputShape(1);
int64 axis = 0;

absl::optional<int64> axis;
if (context->num_inputs() == 3) {
const TensorShape axis_shape = context->InputShape(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
@@ -165,31 +176,73 @@ class GatherOp : public XlaOpKernel {
OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
errors::InvalidArgument("axis must be int32 or int64"));

OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
int64 axis_input;
OP_REQUIRES_OK(context,
context->ConstantInputAsIntScalar(2, &axis_input));

const auto params_dims = input_shape.dims();
OP_REQUIRES(
context, -params_dims <= axis && axis < params_dims,
errors::InvalidArgument("Expected axis in the range [", -params_dims,
", ", params_dims, "), but got ", axis));
if (axis < 0) {
axis += params_dims;
OP_REQUIRES(context,
-params_dims <= axis_input && axis_input < params_dims,
errors::InvalidArgument("Expected axis in the range [",
-params_dims, ", ", params_dims,
"), but got ", axis_input));
if (axis_input < 0) {
axis_input += params_dims;
}
axis = axis_input;
}

if (batch_dims_ != 0) {
if (batch_dims_ < 0) {
batch_dims_ = indices_shape.dims() + batch_dims_;
}

axis = axis.value_or(batch_dims_);

OP_REQUIRES(context,
batch_dims_ >= -indices_shape.dims() &&
batch_dims_ < indices_shape.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices_shape.dims(), ", ",
indices_shape.dims(), "), but got ",
batch_dims_));

OP_REQUIRES(context, batch_dims_ < input_shape.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than rank(input) (",
input_shape.dims(), ")."));

OP_REQUIRES(context, *axis >= batch_dims_,
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than or equal to ",
"axis (", *axis, ")."));
}

axis = axis.value_or(0);
DataType index_type = input_type(1);
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("indices must be int32 or int64"));

xla::XlaOp gather;
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, axis,
/*indices_are_nd=*/false, input_type(0), index_type,
builder, &gather));
if (batch_dims_ > 0) {
gather = xla::TorchIndexSelect(input, indices, *axis, batch_dims_);
} else {
// XlaGather() manages degenerate cases, like empty-indices, which are
// error conditions and caught above if batch_dims is not 0.
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, *axis,
/*indices_are_nd=*/false, input_type(0),
index_type, context->builder(), &gather));
}
context->SetOutput(0, gather);
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);

// The number of batch dimensions, as passed in the batch_dims attribute.
// It must be less than rank(indices).
int32 batch_dims_ = 0;
};

REGISTER_XLA_OP(Name("Gather"), GatherOp);

@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/gather_functor.h"
@@ -39,7 +40,14 @@ class GatherOp : public OpKernel {
// we have the framework do some sort of integer promotion
// automatically, or should that be something that users have to
// do explicitly with a conversion operator in the graph?
explicit GatherOp(OpKernelConstruction* c) : OpKernel(c) {}
explicit GatherOp(OpKernelConstruction* c) : OpKernel(c) {
// Set batch_dims_ to 0 if the attribute does not exist.
if (c->HasAttr("batch_dims")) {
OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
} else {
batch_dims_ = 0;
}
}

void Compute(OpKernelContext* c) override {
const Tensor& params = c->input(0);
@@ -51,7 +59,9 @@ class GatherOp : public OpKernel {
// GatherV2 added an axis argument. For backwards compatibility with Gather,
// fall back to axis 0 if the op does not have an axis input.
int64 axis = 0;
bool axis_is_set = false; // Indicates whether the axis argument was set.
if (c->num_inputs() == 3) {
axis_is_set = true;
const Tensor& axis_tensor = c->input(2);
OP_REQUIRES(c, TensorShapeUtils::IsScalar(axis_tensor.shape()),
errors::InvalidArgument("axis must be scalar"));
@@ -70,12 +80,37 @@ class GatherOp : public OpKernel {
c, axis >= -params.dims() && axis < params.dims(),
errors::InvalidArgument("Expected axis in the range [", -params.dims(),
", ", params.dims(), "), but got ", axis));

if (axis < 0) {
axis = params.dims() + axis;
}

if (batch_dims_ != 0) {
if (batch_dims_ < 0) {
batch_dims_ = indices.dims() + batch_dims_;
}

if (!axis_is_set) axis = batch_dims_;

OP_REQUIRES(
c, batch_dims_ >= -indices.dims() && batch_dims_ < indices.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices.dims(), ", ", indices.dims(),
"), but got ", batch_dims_));

OP_REQUIRES(c, batch_dims_ < params.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than rank(params) (",
params.dims(), ")."));

OP_REQUIRES(c, axis >= batch_dims_,
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than or equal to ",
"axis (", axis, ")."));
}

// Check that we have enough index space
const int64 gather_dim_size = params.dim_size(axis);
int64 gather_dim_size = params.dim_size(axis);
const int64 N = indices.NumElements();
OP_REQUIRES(
c, gather_dim_size <= std::numeric_limits<Index>::max(),
@@ -84,7 +119,7 @@ class GatherOp : public OpKernel {
" indexing: ", gather_dim_size, " > ",
std::numeric_limits<Index>::max()));

// The result shape is params.shape[0:axis] + indices.shape +
// The result shape is params.shape[:axis] + indices.shape[batch_dims:] +
// params.shape[axis + 1:].
TensorShape result_shape;
int64 outer_size = 1;
@@ -93,7 +128,9 @@ class GatherOp : public OpKernel {
result_shape.AddDim(params.dim_size(i));
outer_size *= params.dim_size(i);
}
result_shape.AppendShape(indices.shape());
for (int i = batch_dims_; i < indices.dims(); ++i) {
result_shape.AddDim(indices.dim_size(i));
}
for (int i = axis + 1; i < params.dims(); i++) {
result_shape.AddDim(params.dim_size(i));
inner_size *= params.dim_size(i);
@@ -101,14 +138,53 @@ class GatherOp : public OpKernel {

Tensor* out = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
if (N > 0 && outer_size > 0 && inner_size > 0) {
if (N == 0) return;

if (batch_dims_ > 0) {
// TODO(virimia): Switch to transpose / gather with axis=0 / transpose
// on GPU, to avoid launching a lot of small kernels.

// To avoid copying params (by transposing), run gather for each batch.
int64 batch_size = 1;
for (int i = 0; i < batch_dims_; ++i) {
batch_size *= params.dim_size(i);
}
outer_size /= batch_size;
auto batched_params =
params.shaped<T, 2>({batch_size, params.NumElements() / batch_size});
auto batched_indices =
indices.shaped<Index, 2>({batch_size, N / batch_size});
auto batched_out =
out->shaped<T, 2>({batch_size, out->NumElements() / batch_size});

// TODO(virimia): Investigate the best performance, when the number of
// batches is large, between parallel vs sequential runs.
for (int64 batch = 0; batch < batch_size; ++batch) {
auto params_flat = typename TTypes<T, 3>::ConstTensor(
&batched_params(batch, 0),
{outer_size, gather_dim_size, inner_size});
auto indices_flat = typename TTypes<Index>::ConstFlat(
&batched_indices(batch, 0), {batched_indices.dimension(1)});
auto out_flat = typename TTypes<T, 3>::Tensor(
&batched_out(batch, 0), {outer_size, N, inner_size});

functor::GatherFunctor<Device, T, Index> functor;
const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
}
} else {
auto params_flat =
params.shaped<T, 3>({outer_size, gather_dim_size, inner_size});
auto indices_flat = indices.flat<Index>();
auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});

functor::GatherFunctor<Device, T, Index> functor;
int64 bad_i = functor(c, params_flat, indices_flat, out_flat);
const int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

OP_REQUIRES(
c, bad_i < 0,
@@ -117,6 +193,11 @@ class GatherOp : public OpKernel {
indices_flat(bad_i), " is not in [0, ", gather_dim_size, ")"));
}
}

private:
// The number of batch dimensions, as passed in the batch_dims attribute.
// It must be less than rank(indices).
int32 batch_dims_ = 0;
};

#define REGISTER_GATHER_FULL(dev, type, index_type) \

@@ -1113,6 +1113,7 @@ REGISTER_OP("GatherV2")
.Input("params: Tparams")
.Input("indices: Tindices")
.Input("axis: Taxis")
.Attr("batch_dims: int = 0")
.Output("output: Tparams")
.Attr("Tparams: type")
.Attr("Tindices: {int32,int64}")
@@ -1151,13 +1152,24 @@ REGISTER_OP("GatherV2")
TF_RETURN_IF_ERROR(c->WithRankAtLeast(
params_shape, axis < 0 ? -axis : axis + 1, &unused));

// Note, batch_dims can be negative.
int32 batch_dims;
TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(
params_shape, batch_dims < 0 ? -batch_dims : batch_dims + 1,
&unused));

ShapeHandle params_outer_subshape;
TF_RETURN_IF_ERROR(
c->Subshape(params_shape, 0, axis, &params_outer_subshape));

ShapeHandle indices_inner_subshape;
TF_RETURN_IF_ERROR(
c->Subshape(indices_shape, batch_dims, &indices_inner_subshape));

ShapeHandle out;
TF_RETURN_IF_ERROR(
c->Concatenate(params_outer_subshape, indices_shape, &out));
c->Concatenate(params_outer_subshape, indices_inner_subshape, &out));

// Slice from axis + 1 to the end of params_shape to collect the inner
// dimensions of the result. Special case -1 here since -1 + 1 wraps, and

@@ -293,6 +293,7 @@ TEST(ArrayOpsTest, Gather_ShapeFn) {

TEST(ArrayOpsTest, GatherV2_ShapeFn) {
ShapeInferenceTestOp op("GatherV2");
AddNodeAttr("batch_dims", 0, &op.node_def);

// Tests when axis is unknown.
INFER_OK(op, "?;?;?", "?");

@@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,6 +36,8 @@ from tensorflow.python.platform import test
_TEST_TYPES = (dtypes.int64, dtypes.float32,
dtypes.complex64, dtypes.complex128)

# TODO(virimia): Add a benchmark for gather_v2, with batch_dims and axis set.


class GatherTest(test.TestCase, parameterized.TestCase):

@@ -340,6 +343,12 @@ class GatherTest(test.TestCase, parameterized.TestCase):
result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
self.assertAllEqual(expected, result)

with compat.forward_compatibility_horizon(2019, 6, 11):
result = array_ops.gather(
params, indices, axis=axis, batch_dims=batch_dims)

self.assertAllEqual(expected, result)

@parameterized.parameters([
dict(
params_shape=[2, 3, 4, 5, 6, 7],
@@ -434,6 +443,13 @@ class GatherTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(output_shape, result.shape.as_list())
self.assertAllEqual(expected, result)

with compat.forward_compatibility_horizon(2019, 6, 11):
result = array_ops.gather(
params, indices, axis=axis, batch_dims=batch_dims)

self.assertAllEqual(output_shape, result.shape.as_list())
self.assertAllEqual(expected, result)

def _batchNumpyGather(self, params, indices, axis, batch_dims):
"""Performs a batch gather by making recursive calls to np.take().

@@ -3443,6 +3443,12 @@ def gather(params,
A `Tensor`. Has the same type as `params`.
"""
del validate_indices
if compat.forward_compatible(2019, 6, 10):
if axis is None:
axis = batch_dims
return gen_array_ops.gather_v2(
params, indices, axis, batch_dims=batch_dims, name=name)

if batch_dims != 0:
with ops.name_scope(name, "Gather", [params, indices, axis]):
return _batch_gather(params, indices, batch_dims, axis)

@@ -1430,7 +1430,7 @@ tf_module {
}
member_method {
name: "GatherV2"
argspec: "args=[\'params\', \'indices\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'params\', \'indices\', \'axis\', \'batch_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
}
member_method {
name: "GenerateVocabRemapping"

@@ -1430,7 +1430,7 @@ tf_module {
}
member_method {
name: "GatherV2"
argspec: "args=[\'params\', \'indices\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'params\', \'indices\', \'axis\', \'batch_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
}
member_method {
name: "GenerateVocabRemapping"
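For readers tracing the core/kernels hunk above: when batch_dims > 0 the kernel flattens the leading batch dimensions and runs one ordinary gather per batch element. Below is a minimal NumPy sketch of that behaviour, assuming non-negative axis and batch_dims with batch_dims <= axis and matching leading batch shapes; batched_gather is a hypothetical helper written for illustration, not part of TensorFlow:

    import numpy as np

    def batched_gather(params, indices, axis, batch_dims):
        # Flatten the shared leading batch dimensions of params and indices.
        p = params.reshape((-1,) + params.shape[batch_dims:])
        i = indices.reshape((-1,) + indices.shape[batch_dims:])
        # Run an ordinary gather (np.take) once per batch element, along the
        # axis position measured relative to the flattened batch.
        out = np.stack([np.take(p[b], i[b], axis=axis - batch_dims)
                        for b in range(p.shape[0])])
        # Result shape: params.shape[:axis] + indices.shape[batch_dims:]
        #               + params.shape[axis + 1:]
        return out.reshape(params.shape[:axis] + indices.shape[batch_dims:]
                           + params.shape[axis + 1:])

    params = np.arange(2 * 3 * 4).reshape(2, 3, 4)
    indices = np.array([[2, 1], [0, 3]])  # shape [2, 2], one batch dimension
    print(batched_gather(params, indices, axis=2, batch_dims=1).shape)  # (2, 3, 2)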