Fix error when end_mask and shrink_mask are set at the same axis

PiperOrigin-RevId: 350255382
Change-Id: I1ac180e02a22b62570fe4491fc1e08c6e8fda1de
This commit is contained in:
Thai Nguyen 2021-01-05 17:44:46 -08:00 committed by TensorFlower Gardener
parent 91a9569424
commit 85504f9555
3 changed files with 17 additions and 4 deletions

View File

@ -140,7 +140,7 @@ inline int StopForAxis(const tflite::StridedSliceParams& params,
// start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
// already been adjusted for negative indices.
if (shrink_axis) {
stop = start_for_axis + 1;
return start_for_axis + 1;
}
// end_mask override

View File

@ -745,5 +745,17 @@ TEST(StridedSliceOpTest, In5D_String_IdentityShrinkAxis1) {
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"}));
}
TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis_Endmask_AtSameAxis) {
StridedSliceOpModel<TypeParam> m({2, 2}, {2}, {2}, {2}, 1, 1, 0, 0, 1);
m.SetInput({0, 1, 2, 3});
m.SetBegin({0, -1});
m.SetEnd({0, 0});
m.SetStrides({1, -1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
}
} // namespace
} // namespace tflite

View File

@ -63,7 +63,8 @@ def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0):
end,
strides,
begin_mask=parameters["begin_mask"],
end_mask=parameters["end_mask"])
end_mask=parameters["end_mask"],
shrink_axis_mask=parameters["shrink_axis_mask"])
return tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
@ -241,12 +242,12 @@ def make_strided_slice_tests(options):
"strides": [[2, 1, 3, 1]],
"begin_mask": [8],
"end_mask": [3],
"shrink_axis_mask": [None, -1],
"shrink_axis_mask": [None],
"constant_indices": [True, False],
"fully_quantize": [False],
}
]
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=2)
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=29)
@register_make_test_function()