Fixing strided_slice if a dimension is 0
PiperOrigin-RevId: 304481522 Change-Id: I9ca0c130d4b02cdb0f5f1ca6529c9d28f62e6a7b
This commit is contained in:
parent
d3642db310
commit
3107174868
@ -76,6 +76,10 @@ inline int StartForAxis(const tflite::StridedSliceParams& params,
|
||||
const auto begin_mask = params.begin_mask;
|
||||
const auto* start_indices = params.start_indices;
|
||||
const auto* strides = params.strides;
|
||||
const int axis_size = input_shape.Dims(axis);
|
||||
if (axis_size == 0) {
|
||||
return 0;
|
||||
}
|
||||
// Begin with the specified index.
|
||||
int start = start_indices[axis];
|
||||
|
||||
@ -93,7 +97,6 @@ inline int StartForAxis(const tflite::StridedSliceParams& params,
|
||||
}
|
||||
|
||||
// Handle negative indices
|
||||
int axis_size = input_shape.Dims(axis);
|
||||
if (start < 0) {
|
||||
start += axis_size;
|
||||
}
|
||||
@ -116,6 +119,10 @@ inline int StopForAxis(const tflite::StridedSliceParams& params,
|
||||
const auto shrink_axis_mask = params.shrink_axis_mask;
|
||||
const auto* stop_indices = params.stop_indices;
|
||||
const auto* strides = params.strides;
|
||||
const int axis_size = input_shape.Dims(axis);
|
||||
if (axis_size == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Begin with the specified index
|
||||
const bool shrink_axis = shrink_axis_mask & (1 << axis);
|
||||
@ -142,7 +149,6 @@ inline int StopForAxis(const tflite::StridedSliceParams& params,
|
||||
}
|
||||
|
||||
// Handle negative indices
|
||||
const int axis_size = input_shape.Dims(axis);
|
||||
if (stop < 0) {
|
||||
stop += axis_size;
|
||||
}
|
||||
|
@ -97,6 +97,15 @@ TYPED_TEST(StridedSliceOpTest, UnsupportedArgs) {
|
||||
}
|
||||
#endif
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In1DEmpty) {
|
||||
StridedSliceOpModel<TypeParam> m({0}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({3});
|
||||
m.SetStrides({1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0}));
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In1D) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
|
Loading…
Reference in New Issue
Block a user