[TF2XLA] Support scalar index slicing.

Support scalar index slicing by using inferred output shape.

PiperOrigin-RevId: 268100055
This commit is contained in:
Yunxing Dai 2019-09-09 16:09:25 -07:00 committed by TensorFlower Gardener
parent 44e12c9514
commit 37151067a8
2 changed files with 136 additions and 37 deletions

View File

@ -22,6 +22,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@ -138,6 +139,22 @@ class StridedSliceTest(xla_test.XLATestCase):
self.assertAllEqual([2, 4], result)
def test1DDynamic(self):
for dtype in self.numeric_types:
with self.session():
i = array_ops.placeholder(dtype, shape=[10])
begin = array_ops.placeholder(dtypes.int32, shape=[1])
with self.test_scope():
end = math_ops.add(begin, [1])
o = array_ops.strided_slice(i, begin, end, [1])
params = {
i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
begin: [0]
}
result = o.eval(feed_dict=params)
self.assertAllEqual([0], result)
def test1DNegativeStride(self):
for dtype in self.numeric_types:
with self.session():
@ -179,6 +196,22 @@ class StridedSliceTest(xla_test.XLATestCase):
self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape)
def test2DFullSlice(self):
for dtype in self.numeric_types:
with self.session():
with self.test_scope():
i = array_ops.placeholder(dtype, shape=[2, 4])
begin = array_ops.placeholder(dtypes.int32, shape=[2])
end = math_ops.add(begin, [1, 1])
o = array_ops.strided_slice(i, begin, end, [1, 1])
params = {
i: [[0, 1, 2, 3], [4, 5, 6, 7]],
begin: [1, 1]
}
result = o.eval(feed_dict=params)
self.assertAllEqual([[5]], result)
def test3D(self):
for dtype in self.numeric_types:
with self.session():

View File

@ -14,12 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/strided_slice_op.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
@ -44,60 +46,124 @@ class StridedSliceOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape begin_shape = ctx->InputShape("begin");
OP_REQUIRES(
ctx, begin_shape.dims() == 1,
errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
TensorShape final_shape;
absl::InlinedVector<int64, 4> begin;
absl::InlinedVector<int64, 4> end;
absl::InlinedVector<int64, 4> strides;
xla::Literal begin_literal, end_literal, strides_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok();
bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok();
OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
Tensor begin_tensor, end_tensor, strides_tensor;
OP_REQUIRES_OK(
ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
OP_REQUIRES_OK(ctx,
LiteralToHostTensor(end_literal, index_type_, &end_tensor));
if (begin_is_constant) {
OP_REQUIRES_OK(
ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
}
if (end_is_constant) {
OP_REQUIRES_OK(
ctx, LiteralToHostTensor(end_literal, index_type_, &end_tensor));
}
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
&strides_tensor));
TensorShape dummy_processing_shape;
TensorShape final_shape;
PartialTensorShape dummy_processing_shape, partial_final_shape;
bool dummy = false;
OP_REQUIRES_OK(ctx,
ValidateStridedSliceOp(
&begin_tensor, &end_tensor, strides_tensor, input_shape,
begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
shrink_axis_mask_, &dummy_processing_shape, &final_shape,
&dummy, &dummy, &dummy, &begin, &end, &strides));
OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
begin_is_constant ? &begin_tensor : nullptr,
end_is_constant ? &end_tensor : nullptr,
strides_tensor, input_shape, begin_mask_, end_mask_,
ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
&dummy_processing_shape, &partial_final_shape,
&dummy, &dummy, &dummy, &begin, &end, &strides));
absl::InlinedVector<int64, 4> dimensions_to_reverse;
absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
for (int i = 0; i < begin.size(); ++i) {
if (strides[i] > 0) {
slice_begin.push_back(begin[i]);
slice_end.push_back(std::max(end[i], begin[i]));
slice_strides.push_back(strides[i]);
} else {
// Negative stride: swap begin and end, add 1 because the interval
// is semi-open, and mark the dimension to be reversed.
slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1,
input_shape.dim_size(i) - begin[i] - 1));
slice_strides.push_back(-strides[i]);
dimensions_to_reverse.push_back(i);
}
}
OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape),
errors::InvalidArgument(
"XLA can't deduce compile time constant output "
"shape for strided slice: ",
partial_final_shape.DebugString(),
", output shape must be a compile-time constant"));
xla::XlaOp slice = ctx->Input(0);
if (!dimensions_to_reverse.empty()) {
slice = xla::Rev(slice, dimensions_to_reverse);
if (begin_is_constant && end_is_constant) {
absl::InlinedVector<int64, 4> dimensions_to_reverse;
absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
for (int i = 0; i < begin.size(); ++i) {
if (strides[i] > 0) {
slice_begin.push_back(begin[i]);
slice_end.push_back(std::max(end[i], begin[i]));
slice_strides.push_back(strides[i]);
} else {
// Negative stride: swap begin and end, add 1 because the interval
// is semi-open, and mark the dimension to be reversed.
slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1,
input_shape.dim_size(i) - begin[i] - 1));
slice_strides.push_back(-strides[i]);
dimensions_to_reverse.push_back(i);
}
}
if (!dimensions_to_reverse.empty()) {
slice = xla::Rev(slice, dimensions_to_reverse);
}
slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
} else {
// When output shape is fully defined, it must be a size one slice:
//
// 1. The number of output elements has to equal to number of input
// elements that are sliced.
// 2. The stride of the slice dimensions must be exact one.
int64 output_elements = final_shape.num_elements();
int64 input_elements_sliced = 1;
int64 slicing_dim_size = begin_shape.dim_size(0);
// We only support slicing major dimensions, so minor dimensions after
for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
input_elements_sliced *= input_shape.dim_size(d);
}
OP_REQUIRES(
ctx, output_elements == input_elements_sliced,
errors::InvalidArgument(
"The number of output elements ", output_elements,
" has to equal to number of input elements that are sliced ",
input_elements_sliced, " when input indices are not constant."));
for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
OP_REQUIRES(
ctx, strides[i] == 1,
errors::InvalidArgument(
"Strides have to be one when inputs are not constant."));
}
// When inputs are not compile time constants, shape inference can only
// inference size 1 slice.
std::vector<int64> slice_sizes(slicing_dim_size, 1);
std::vector<xla::XlaOp> start_indices;
for (int64 d = 0; d < slicing_dim_size; ++d) {
auto index = xla::Slice(ctx->Input("begin"), {d}, {d + 1}, {1});
// Convert index to scalar.
start_indices.push_back(xla::Reshape(index, {}));
}
for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
// For non-slice dims, naturally we get the full slice starting from 0.
slice_sizes.push_back(input_shape.dim_size(d));
start_indices.push_back(
xla::Zero(ctx->builder(), ctx->InputXlaType("begin")));
}
std::vector<int64> output_shape_dim_sizes;
slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
}
slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
}