From db58c826ae73aafe785f501da0bad5818ac7eaea Mon Sep 17 00:00:00 2001 From: DavidNorman Date: Tue, 13 Jun 2017 07:34:44 +0100 Subject: [PATCH] Use only positive strides in the XLA stride op --- .../compiler/tf2xla/kernels/split_op.cc | 2 +- .../tf2xla/kernels/strided_slice_op.cc | 26 +++++++++++++++++-- .../compiler/xla/service/shape_inference.cc | 8 +----- .../xla/service/shape_inference_test.cc | 9 ------- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 5efc829ee56..42bde900422 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -188,7 +188,7 @@ class SplitVOp : public XlaOpKernel { std::vector begin(input_shape.dims(), 0); auto dim_sizes = input_shape.dim_sizes(); std::vector limits(dim_sizes.begin(), dim_sizes.end()); - std::vector strides(input_shape.dims(), 0); + std::vector strides(input_shape.dims(), 1); for (int i = 0; i < num_split; ++i) { TensorShape output_shape(input_shape); int slice_size = split_sizes_vec[i]; diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 9f0f4f4a4bc..9eb68998310 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -75,8 +75,30 @@ class StridedSliceOp : public XlaOpKernel { &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, &dummy, &dummy, &begin, &end, &strides)); - xla::ComputationDataHandle slice = - ctx->builder()->Slice(ctx->Input(0), begin, end, strides); + gtl::InlinedVector dimensions_to_reverse; + gtl::InlinedVector slice_begin, slice_end, slice_strides; + + for (int i = 0; i < begin.size(); ++i) { + if (strides[i] > 0) { + slice_begin.push_back(begin[i]); + slice_end.push_back(end[i]); + slice_strides.push_back(strides[i]); + } else { + // Negative stride: swap begin and end, add 1 because the interval + // is semi-open, and mark the dimension to be reversed. + slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); + slice_end.push_back(input_shape.dim_size(i) - end[i] - 1); + slice_strides.push_back(-strides[i]); + dimensions_to_reverse.push_back(i); + } + } + + xla::ComputationDataHandle slice = ctx->Input(0); + if (!dimensions_to_reverse.empty()) { + slice = ctx->builder()->Rev(slice, dimensions_to_reverse); + } + + slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides); slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 87f278b5056..d6436cf988d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1056,13 +1056,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } sizes.push_back((limit_index - start_index + stride - 1) / stride); } else { - if (start_index < limit_index) { - return InvalidArgument( - "limit index (%lld) must be less than or equal to " - "start index (%lld) in slice with negative stride", - limit_index, start_index); - } - sizes.push_back((start_index - limit_index - stride - 1) / -stride); + return InvalidArgument("Negative strides not supported"); } } diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 31ee1c26873..8c731ae2976 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -697,15 +697,6 @@ TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) { ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred)); } -TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithNegativeStrides) { - Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = - ShapeInference::InferSliceShape(matrix_shape, {64, 64}, {32, 0}, {-1, -4}); - ASSERT_IS_OK(inferred_status.status()); - Shape inferred = inferred_status.ValueOrDie(); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 16}), inferred)); -} - TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); auto inferred_status =