From 85504f9555ca804874c53d03bc7b145a7933b570 Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Tue, 5 Jan 2021 17:44:46 -0800 Subject: [PATCH] Fix error when end_mask and shrink_mask are set at the same axis PiperOrigin-RevId: 350255382 Change-Id: I1ac180e02a22b62570fe4491fc1e08c6e8fda1de --- .../lite/kernels/internal/strided_slice_logic.h | 2 +- tensorflow/lite/kernels/strided_slice_test.cc | 12 ++++++++++++ tensorflow/lite/testing/op_tests/strided_slice.py | 7 ++++--- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/kernels/internal/strided_slice_logic.h b/tensorflow/lite/kernels/internal/strided_slice_logic.h index 3d91fbdb8e2..bfe84050dca 100644 --- a/tensorflow/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/lite/kernels/internal/strided_slice_logic.h @@ -140,7 +140,7 @@ inline int StopForAxis(const tflite::StridedSliceParams& params, // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has // already been adjusted for negative indices. if (shrink_axis) { - stop = start_for_axis + 1; + return start_for_axis + 1; } // end_mask override diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index 98521b889f9..ef50d29991e 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -745,5 +745,17 @@ TEST(StridedSliceOpTest, In5D_String_IdentityShrinkAxis1) { EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"})); } +TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis_Endmask_AtSameAxis) { + StridedSliceOpModel m({2, 2}, {2}, {2}, {2}, 1, 1, 0, 0, 1); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({0, -1}); + m.SetEnd({0, 0}); + m.SetStrides({1, -1}); + + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/testing/op_tests/strided_slice.py b/tensorflow/lite/testing/op_tests/strided_slice.py index 8668e139f34..daf5449e54f 100644 --- a/tensorflow/lite/testing/op_tests/strided_slice.py +++ b/tensorflow/lite/testing/op_tests/strided_slice.py @@ -63,7 +63,8 @@ def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0): end, strides, begin_mask=parameters["begin_mask"], - end_mask=parameters["end_mask"]) + end_mask=parameters["end_mask"], + shrink_axis_mask=parameters["shrink_axis_mask"]) return tensors, [out] def build_inputs(parameters, sess, inputs, outputs): @@ -241,12 +242,12 @@ def make_strided_slice_tests(options): "strides": [[2, 1, 3, 1]], "begin_mask": [8], "end_mask": [3], - "shrink_axis_mask": [None, -1], + "shrink_axis_mask": [None], "constant_indices": [True, False], "fully_quantize": [False], } ] - _make_strided_slice_tests(options, test_parameters, expected_tf_failures=2) + _make_strided_slice_tests(options, test_parameters, expected_tf_failures=29) @register_make_test_function()