Fix error when end_mask and shrink_mask are set at the same axis
PiperOrigin-RevId: 350255382 Change-Id: I1ac180e02a22b62570fe4491fc1e08c6e8fda1de
This commit is contained in:
parent
91a9569424
commit
85504f9555
@ -140,7 +140,7 @@ inline int StopForAxis(const tflite::StridedSliceParams& params,
|
||||
// start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
|
||||
// already been adjusted for negative indices.
|
||||
if (shrink_axis) {
|
||||
stop = start_for_axis + 1;
|
||||
return start_for_axis + 1;
|
||||
}
|
||||
|
||||
// end_mask override
|
||||
|
@ -745,5 +745,17 @@ TEST(StridedSliceOpTest, In5D_String_IdentityShrinkAxis1) {
|
||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"}));
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis_Endmask_AtSameAxis) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 2}, {2}, {2}, {2}, 1, 1, 0, 0, 1);
|
||||
m.SetInput({0, 1, 2, 3});
|
||||
m.SetBegin({0, -1});
|
||||
m.SetEnd({0, 0});
|
||||
m.SetStrides({1, -1});
|
||||
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -63,7 +63,8 @@ def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0):
|
||||
end,
|
||||
strides,
|
||||
begin_mask=parameters["begin_mask"],
|
||||
end_mask=parameters["end_mask"])
|
||||
end_mask=parameters["end_mask"],
|
||||
shrink_axis_mask=parameters["shrink_axis_mask"])
|
||||
return tensors, [out]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
@ -241,12 +242,12 @@ def make_strided_slice_tests(options):
|
||||
"strides": [[2, 1, 3, 1]],
|
||||
"begin_mask": [8],
|
||||
"end_mask": [3],
|
||||
"shrink_axis_mask": [None, -1],
|
||||
"shrink_axis_mask": [None],
|
||||
"constant_indices": [True, False],
|
||||
"fully_quantize": [False],
|
||||
}
|
||||
]
|
||||
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=2)
|
||||
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=29)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
|
Loading…
x
Reference in New Issue
Block a user