[TF2XLA] Support scalar index slicing.
Support scalar index slicing by using the inferred output shape.

PiperOrigin-RevId: 268100055
commit 37151067a8
parent 44e12c9514
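
In short, the newly supported pattern looks like the sketch below: a strided_slice whose begin is only known at runtime now compiles under XLA, as long as the slice is size one so the output shape is still statically inferred. This is a minimal sketch using the public TF 1.x graph API (the tf.compat.v1 session setup is an assumption; the tests added below drive the equivalent graph through xla_test.XLATestCase and test_scope()).

# Sketch of the newly supported pattern; assumes TF 1.x graph mode.
import numpy as np
import tensorflow.compat.v1 as tf

i = tf.placeholder(tf.float32, shape=[10])
begin = tf.placeholder(tf.int32, shape=[1])  # begin is NOT a compile-time constant
end = begin + [1]                            # size-1 slice => static output shape [1]
o = tf.strided_slice(i, begin, end, strides=[1])

with tf.Session() as sess:
  print(sess.run(o, {i: np.arange(10, dtype=np.float32), begin: [3]}))  # [3.]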
@@ -22,6 +22,7 @@ from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import googletest
 
 
@@ -138,6 +139,22 @@ class StridedSliceTest(xla_test.XLATestCase):
 
       self.assertAllEqual([2, 4], result)
 
+  def test1DDynamic(self):
+    for dtype in self.numeric_types:
+      with self.session():
+        i = array_ops.placeholder(dtype, shape=[10])
+        begin = array_ops.placeholder(dtypes.int32, shape=[1])
+        with self.test_scope():
+          end = math_ops.add(begin, [1])
+          o = array_ops.strided_slice(i, begin, end, [1])
+        params = {
+            i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+            begin: [0]
+        }
+        result = o.eval(feed_dict=params)
+
+        self.assertAllEqual([0], result)
+
   def test1DNegativeStride(self):
     for dtype in self.numeric_types:
       with self.session():
@@ -179,6 +196,22 @@ class StridedSliceTest(xla_test.XLATestCase):
 
       self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape)
 
+  def test2DFullSlice(self):
+    for dtype in self.numeric_types:
+      with self.session():
+        with self.test_scope():
+          i = array_ops.placeholder(dtype, shape=[2, 4])
+          begin = array_ops.placeholder(dtypes.int32, shape=[2])
+          end = math_ops.add(begin, [1, 1])
+          o = array_ops.strided_slice(i, begin, end, [1, 1])
+        params = {
+            i: [[0, 1, 2, 3], [4, 5, 6, 7]],
+            begin: [1, 1]
+        }
+        result = o.eval(feed_dict=params)
+
+        self.assertAllEqual([[5]], result)
+
   def test3D(self):
     for dtype in self.numeric_types:
       with self.session():
@@ -14,12 +14,14 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/util/strided_slice_op.h"
 
+#include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
@@ -44,60 +46,124 @@ class StridedSliceOp : public XlaOpKernel {
 
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
+    const TensorShape begin_shape = ctx->InputShape("begin");
+
+    OP_REQUIRES(
+        ctx, begin_shape.dims() == 1,
+        errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
 
-    TensorShape final_shape;
     absl::InlinedVector<int64, 4> begin;
     absl::InlinedVector<int64, 4> end;
     absl::InlinedVector<int64, 4> strides;
 
     xla::Literal begin_literal, end_literal, strides_literal;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
-    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
+    bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok();
+    bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok();
+
     OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
 
     Tensor begin_tensor, end_tensor, strides_tensor;
-    OP_REQUIRES_OK(
-        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
-    OP_REQUIRES_OK(ctx,
-                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
+    if (begin_is_constant) {
+      OP_REQUIRES_OK(
+          ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
+    }
+    if (end_is_constant) {
+      OP_REQUIRES_OK(
+          ctx, LiteralToHostTensor(end_literal, index_type_, &end_tensor));
+    }
     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                             &strides_tensor));
 
-    TensorShape dummy_processing_shape;
+    TensorShape final_shape;
+    PartialTensorShape dummy_processing_shape, partial_final_shape;
     bool dummy = false;
-    OP_REQUIRES_OK(ctx,
-                   ValidateStridedSliceOp(
-                       &begin_tensor, &end_tensor, strides_tensor, input_shape,
-                       begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
-                       shrink_axis_mask_, &dummy_processing_shape, &final_shape,
-                       &dummy, &dummy, &dummy, &begin, &end, &strides));
+    OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
+                            begin_is_constant ? &begin_tensor : nullptr,
+                            end_is_constant ? &end_tensor : nullptr,
+                            strides_tensor, input_shape, begin_mask_, end_mask_,
+                            ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+                            &dummy_processing_shape, &partial_final_shape,
+                            &dummy, &dummy, &dummy, &begin, &end, &strides));
 
-    absl::InlinedVector<int64, 4> dimensions_to_reverse;
-    absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
-
-    for (int i = 0; i < begin.size(); ++i) {
-      if (strides[i] > 0) {
-        slice_begin.push_back(begin[i]);
-        slice_end.push_back(std::max(end[i], begin[i]));
-        slice_strides.push_back(strides[i]);
-      } else {
-        // Negative stride: swap begin and end, add 1 because the interval
-        // is semi-open, and mark the dimension to be reversed.
-        slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
-        slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1,
-                                     input_shape.dim_size(i) - begin[i] - 1));
-        slice_strides.push_back(-strides[i]);
-        dimensions_to_reverse.push_back(i);
-      }
-    }
+    OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape),
+                errors::InvalidArgument(
+                    "XLA can't deduce compile time constant output "
+                    "shape for strided slice: ",
+                    partial_final_shape.DebugString(),
+                    ", output shape must be a compile-time constant"));
 
     xla::XlaOp slice = ctx->Input(0);
-    if (!dimensions_to_reverse.empty()) {
-      slice = xla::Rev(slice, dimensions_to_reverse);
-    }
+    if (begin_is_constant && end_is_constant) {
+      absl::InlinedVector<int64, 4> dimensions_to_reverse;
+      absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
+      for (int i = 0; i < begin.size(); ++i) {
+        if (strides[i] > 0) {
+          slice_begin.push_back(begin[i]);
+          slice_end.push_back(std::max(end[i], begin[i]));
+          slice_strides.push_back(strides[i]);
+        } else {
+          // Negative stride: swap begin and end, add 1 because the interval
+          // is semi-open, and mark the dimension to be reversed.
+          slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
+          slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1,
+                                       input_shape.dim_size(i) - begin[i] - 1));
+          slice_strides.push_back(-strides[i]);
+          dimensions_to_reverse.push_back(i);
+        }
+      }
+      if (!dimensions_to_reverse.empty()) {
+        slice = xla::Rev(slice, dimensions_to_reverse);
+      }
+      slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
+    } else {
+      // When the output shape is fully defined but the indices are not
+      // compile-time constants, this must be a size-one slice:
+      //
+      // 1. The number of output elements has to equal the number of input
+      //    elements that are sliced.
+      // 2. The stride of each sliced dimension must be exactly one.
+      int64 output_elements = final_shape.num_elements();
+
+      int64 input_elements_sliced = 1;
+      int64 slicing_dim_size = begin_shape.dim_size(0);
+      // We only support slicing major dimensions, so minor dimensions after
+      // slicing_dim_size are taken in full.
+      for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
+        input_elements_sliced *= input_shape.dim_size(d);
+      }
+
+      OP_REQUIRES(
+          ctx, output_elements == input_elements_sliced,
+          errors::InvalidArgument(
+              "The number of output elements ", output_elements,
+              " has to equal the number of input elements that are sliced ",
+              input_elements_sliced, " when input indices are not constant."));
+
+      for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
+        OP_REQUIRES(
+            ctx, strides[i] == 1,
+            errors::InvalidArgument(
+                "Strides have to be one when inputs are not constant."));
+      }
+
+      // When the inputs are not compile-time constants, shape inference can
+      // only infer a size-1 slice.
+      std::vector<int64> slice_sizes(slicing_dim_size, 1);
+      std::vector<xla::XlaOp> start_indices;
+      for (int64 d = 0; d < slicing_dim_size; ++d) {
+        auto index = xla::Slice(ctx->Input("begin"), {d}, {d + 1}, {1});
+        // Convert the index to a scalar.
+        start_indices.push_back(xla::Reshape(index, {}));
+      }
+
+      for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
+        // For non-sliced dims, take the full dimension starting at offset 0.
+        slice_sizes.push_back(input_shape.dim_size(d));
+        start_indices.push_back(
+            xla::Zero(ctx->builder(), ctx->InputXlaType("begin")));
+      }
+
+      slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
+    }
 
-    slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
-
     slice = xla::Reshape(slice, final_shape.dim_sizes());
     ctx->SetOutput(0, slice);
   }
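
To make the non-constant branch above concrete, here is a small NumPy sketch of the lowering it emits (an illustration only, not the XLA API; the function name and parameters are hypothetical): each element of begin becomes a scalar start index of a size-one DynamicSlice, the remaining minor dimensions are taken in full from offset 0, and the result is reshaped to the statically inferred output shape.

import numpy as np

def dynamic_slice_lowering(x, begin, final_shape):
  # Sketch of the else-branch: 'begin' is only known at runtime.
  # Each sliced (major) dimension gets size 1; the remaining (minor)
  # dimensions are taken in full starting at offset 0.
  slicing_dims = len(begin)
  sizes = [1] * slicing_dims + list(x.shape[slicing_dims:])
  starts = list(begin) + [0] * (x.ndim - slicing_dims)
  out = x[tuple(slice(s, s + n) for s, n in zip(starts, sizes))]
  # Reshape to the statically inferred output shape.
  return out.reshape(final_shape)

# Mirrors test2DFullSlice above: begin=[1, 1] selects x[1, 1] == 5.
x = np.array([[0, 1, 2, 3], [4, 5, 6, 7]])
print(dynamic_slice_lowering(x, [1, 1], [1, 1]))  # [[5]]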