From 65b5275b552578e3f1a8991612300da282213f71 Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Thu, 17 Sep 2020 08:39:26 -0700 Subject: [PATCH] Handle edge case in TFLite StridedSlice: the length of begin is different from the rank of input PiperOrigin-RevId: 332241089 Change-Id: I41b1436d78db87ac78c60e2d72bc8c9016e0106e --- tensorflow/lite/kernels/strided_slice.cc | 24 +++++++++++++------ tensorflow/lite/kernels/strided_slice_test.cc | 23 ++++++++++++++++++ .../lite/testing/op_tests/strided_slice.py | 8 +++---- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc index 83221cd4a3d..d10e99c1997 100644 --- a/tensorflow/lite/kernels/strided_slice.cc +++ b/tensorflow/lite/kernels/strided_slice.cc @@ -71,17 +71,27 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) { op_params.stop_indices_count = op_context->dims; op_params.strides_count = op_context->dims; - for (int i = 0; i < op_context->dims; ++i) { - op_params.start_indices[i] = GetTensorData(op_context->begin)[i]; - op_params.stop_indices[i] = GetTensorData(op_context->end)[i]; - op_params.strides[i] = GetTensorData(op_context->strides)[i]; - } - op_params.begin_mask = op_context->params->begin_mask; op_params.ellipsis_mask = 0; op_params.end_mask = op_context->params->end_mask; op_params.new_axis_mask = 0; op_params.shrink_axis_mask = op_context->params->shrink_axis_mask; + + int begin_count = GetTensorShape(op_context->begin).Dims(0); + for (int i = 0; i < begin_count; ++i) { + op_params.start_indices[i] = GetTensorData(op_context->begin)[i]; + op_params.stop_indices[i] = GetTensorData(op_context->end)[i]; + op_params.strides[i] = GetTensorData(op_context->strides)[i]; + } + + // If the length of begin and end smaller than number of input dims, set the + // mask bit of begin and end for that index. + for (int i = begin_count; i < op_context->dims; ++i) { + op_params.start_indices[i] = op_params.stop_indices[i] = 0; + op_params.strides[i] = 1; + op_params.begin_mask |= (1 << i); + op_params.end_mask |= (1 << i); + } return op_params; } @@ -95,7 +105,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, RuntimeShape input_shape = GetTensorShape(op_context->input); for (int idx = op_context->dims - 1; idx >= 0; --idx) { - int32_t stride = GetTensorData(op_context->strides)[idx]; + int32_t stride = op_params.strides[idx]; TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); int32_t begin = diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index 5f625d3f201..f174c236d98 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -649,5 +649,28 @@ TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); } + +TYPED_TEST(StridedSliceOpTest, In3D_SmallBegin) { + StridedSliceOpModel m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0}); + m.SetEnd({1}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) { + StridedSliceOpModel m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0}); + m.SetEnd({1}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/testing/op_tests/strided_slice.py b/tensorflow/lite/testing/op_tests/strided_slice.py index bc1b0115c24..4f2cbd9b64a 100644 --- a/tensorflow/lite/testing/op_tests/strided_slice.py +++ b/tensorflow/lite/testing/op_tests/strided_slice.py @@ -43,17 +43,17 @@ def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0): begin = tf.compat.v1.placeholder( dtype=parameters["index_type"], name="begin", - shape=[len(parameters["input_shape"])]) + shape=[len(parameters["begin"])]) end = tf.compat.v1.placeholder( dtype=parameters["index_type"], name="end", - shape=[len(parameters["input_shape"])]) + shape=[len(parameters["end"])]) strides = None if parameters["strides"] is not None: strides = tf.compat.v1.placeholder( dtype=parameters["index_type"], name="strides", - shape=[len(parameters["input_shape"])]) + shape=[len(parameters["strides"])]) tensors = [input_tensor, begin, end] if strides is not None: tensors.append(strides) @@ -141,7 +141,7 @@ def make_strided_slice_tests(options): "begin_mask": [0], "end_mask": [0], "shrink_axis_mask": [1], - "constant_indices": [True], + "constant_indices": [True, False], "fully_quantize": [False], }, # 2-D