Fix error when end_mask and shrink_mask are set at the same axis
PiperOrigin-RevId: 350255382 Change-Id: I1ac180e02a22b62570fe4491fc1e08c6e8fda1de
This commit is contained in:
parent
91a9569424
commit
85504f9555
@ -140,7 +140,7 @@ inline int StopForAxis(const tflite::StridedSliceParams& params,
|
|||||||
// start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
|
// start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
|
||||||
// already been adjusted for negative indices.
|
// already been adjusted for negative indices.
|
||||||
if (shrink_axis) {
|
if (shrink_axis) {
|
||||||
stop = start_for_axis + 1;
|
return start_for_axis + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// end_mask override
|
// end_mask override
|
||||||
|
@ -745,5 +745,17 @@ TEST(StridedSliceOpTest, In5D_String_IdentityShrinkAxis1) {
|
|||||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"}));
|
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis_Endmask_AtSameAxis) {
|
||||||
|
StridedSliceOpModel<TypeParam> m({2, 2}, {2}, {2}, {2}, 1, 1, 0, 0, 1);
|
||||||
|
m.SetInput({0, 1, 2, 3});
|
||||||
|
m.SetBegin({0, -1});
|
||||||
|
m.SetEnd({0, 0});
|
||||||
|
m.SetStrides({1, -1});
|
||||||
|
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -63,7 +63,8 @@ def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0):
|
|||||||
end,
|
end,
|
||||||
strides,
|
strides,
|
||||||
begin_mask=parameters["begin_mask"],
|
begin_mask=parameters["begin_mask"],
|
||||||
end_mask=parameters["end_mask"])
|
end_mask=parameters["end_mask"],
|
||||||
|
shrink_axis_mask=parameters["shrink_axis_mask"])
|
||||||
return tensors, [out]
|
return tensors, [out]
|
||||||
|
|
||||||
def build_inputs(parameters, sess, inputs, outputs):
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
@ -241,12 +242,12 @@ def make_strided_slice_tests(options):
|
|||||||
"strides": [[2, 1, 3, 1]],
|
"strides": [[2, 1, 3, 1]],
|
||||||
"begin_mask": [8],
|
"begin_mask": [8],
|
||||||
"end_mask": [3],
|
"end_mask": [3],
|
||||||
"shrink_axis_mask": [None, -1],
|
"shrink_axis_mask": [None],
|
||||||
"constant_indices": [True, False],
|
"constant_indices": [True, False],
|
||||||
"fully_quantize": [False],
|
"fully_quantize": [False],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=2)
|
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=29)
|
||||||
|
|
||||||
|
|
||||||
@register_make_test_function()
|
@register_make_test_function()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user