Fixing strided_slice if a dimension is 0

PiperOrigin-RevId: 304481522
Change-Id: I9ca0c130d4b02cdb0f5f1ca6529c9d28f62e6a7b
This commit is contained in:
A. Unique TensorFlower 2020-04-02 14:45:17 -07:00 committed by TensorFlower Gardener
parent d3642db310
commit 3107174868
2 changed files with 17 additions and 2 deletions

View File

@ -76,6 +76,10 @@ inline int StartForAxis(const tflite::StridedSliceParams& params,
const auto begin_mask = params.begin_mask;
const auto* start_indices = params.start_indices;
const auto* strides = params.strides;
const int axis_size = input_shape.Dims(axis);
if (axis_size == 0) {
return 0;
}
// Begin with the specified index.
int start = start_indices[axis];
@ -93,7 +97,6 @@ inline int StartForAxis(const tflite::StridedSliceParams& params,
}
// Handle negative indices
int axis_size = input_shape.Dims(axis);
if (start < 0) {
start += axis_size;
}
@ -116,6 +119,10 @@ inline int StopForAxis(const tflite::StridedSliceParams& params,
const auto shrink_axis_mask = params.shrink_axis_mask;
const auto* stop_indices = params.stop_indices;
const auto* strides = params.strides;
const int axis_size = input_shape.Dims(axis);
if (axis_size == 0) {
return 0;
}
// Begin with the specified index
const bool shrink_axis = shrink_axis_mask & (1 << axis);
@ -142,7 +149,6 @@ inline int StopForAxis(const tflite::StridedSliceParams& params,
}
// Handle negative indices
const int axis_size = input_shape.Dims(axis);
if (stop < 0) {
stop += axis_size;
}

View File

@ -97,6 +97,15 @@ TYPED_TEST(StridedSliceOpTest, UnsupportedArgs) {
}
#endif
TYPED_TEST(StridedSliceOpTest, In1DEmpty) {
StridedSliceOpModel<TypeParam> m({0}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetBegin({1});
m.SetEnd({3});
m.SetStrides({1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0}));
}
TYPED_TEST(StridedSliceOpTest, In1D) {
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});