Handle edge case in TFLite StridedSlice: the length of begin is different from the rank of input

PiperOrigin-RevId: 332241089
Change-Id: I41b1436d78db87ac78c60e2d72bc8c9016e0106e
This commit is contained in:
Thai Nguyen 2020-09-17 08:39:26 -07:00 committed by TensorFlower Gardener
parent df1688cae1
commit 65b5275b55
3 changed files with 44 additions and 11 deletions

View File

@ -71,17 +71,27 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) {
op_params.stop_indices_count = op_context->dims;
op_params.strides_count = op_context->dims;
for (int i = 0; i < op_context->dims; ++i) {
op_params.start_indices[i] = GetTensorData<int32_t>(op_context->begin)[i];
op_params.stop_indices[i] = GetTensorData<int32_t>(op_context->end)[i];
op_params.strides[i] = GetTensorData<int32_t>(op_context->strides)[i];
}
op_params.begin_mask = op_context->params->begin_mask;
op_params.ellipsis_mask = 0;
op_params.end_mask = op_context->params->end_mask;
op_params.new_axis_mask = 0;
op_params.shrink_axis_mask = op_context->params->shrink_axis_mask;
int begin_count = GetTensorShape(op_context->begin).Dims(0);
for (int i = 0; i < begin_count; ++i) {
op_params.start_indices[i] = GetTensorData<int32_t>(op_context->begin)[i];
op_params.stop_indices[i] = GetTensorData<int32_t>(op_context->end)[i];
op_params.strides[i] = GetTensorData<int32_t>(op_context->strides)[i];
}
// If the length of begin and end smaller than number of input dims, set the
// mask bit of begin and end for that index.
for (int i = begin_count; i < op_context->dims; ++i) {
op_params.start_indices[i] = op_params.stop_indices[i] = 0;
op_params.strides[i] = 1;
op_params.begin_mask |= (1 << i);
op_params.end_mask |= (1 << i);
}
return op_params;
}
@ -95,7 +105,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
RuntimeShape input_shape = GetTensorShape(op_context->input);
for (int idx = op_context->dims - 1; idx >= 0; --idx) {
int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
int32_t stride = op_params.strides[idx];
TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
int32_t begin =

View File

@ -649,5 +649,28 @@ TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4}));
}
TYPED_TEST(StridedSliceOpTest, In3D_SmallBegin) {
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0});
m.SetEnd({1});
m.SetStrides({1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) {
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0});
m.SetEnd({1});
m.SetStrides({1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
} // namespace
} // namespace tflite

View File

@ -43,17 +43,17 @@ def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0):
begin = tf.compat.v1.placeholder(
dtype=parameters["index_type"],
name="begin",
shape=[len(parameters["input_shape"])])
shape=[len(parameters["begin"])])
end = tf.compat.v1.placeholder(
dtype=parameters["index_type"],
name="end",
shape=[len(parameters["input_shape"])])
shape=[len(parameters["end"])])
strides = None
if parameters["strides"] is not None:
strides = tf.compat.v1.placeholder(
dtype=parameters["index_type"],
name="strides",
shape=[len(parameters["input_shape"])])
shape=[len(parameters["strides"])])
tensors = [input_tensor, begin, end]
if strides is not None:
tensors.append(strides)
@ -141,7 +141,7 @@ def make_strided_slice_tests(options):
"begin_mask": [0],
"end_mask": [0],
"shrink_axis_mask": [1],
"constant_indices": [True],
"constant_indices": [True, False],
"fully_quantize": [False],
},
# 2-D