diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h index 38769d1bc0b..5a4227c9971 100644 --- a/tensorflow/lite/kernels/internal/types.h +++ b/tensorflow/lite/kernels/internal/types.h @@ -1040,11 +1040,11 @@ struct SqueezeParams { struct StridedSliceParams { int8 start_indices_count; - int16 start_indices[4]; + int32 start_indices[4]; int8 stop_indices_count; - int16 stop_indices[4]; + int32 stop_indices[4]; int8 strides_count; - int16 strides[4]; + int32 strides[4]; int16 begin_mask; int16 ellipsis_mask; diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index b8a1b9ba704..7464c28793a 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -49,6 +49,9 @@ class StridedSliceOpModel : public SingleOpModel { void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetInput(const std::vector data) { + PopulateTensor(input_, data); + } void SetBegin(std::initializer_list data) { PopulateTensor(begin_, data); } @@ -98,6 +101,21 @@ TEST(StridedSliceOpTest, In1D) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); } +TEST(StridedSliceOpTest, In1D_Int32End) { + StridedSliceOpModel<> m({32768}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + std::vector values; + for (int i = 0; i < 32768; i++) { + values.push_back(i); + } + m.SetInput(values); + m.SetBegin({0}); + m.SetEnd({32768}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({32768})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(values)); +} + TEST(StridedSliceOpTest, In1D_EmptyOutput) { StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); @@ -567,8 +585,8 @@ TEST(StridedSliceOpTest, RunTwice) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, - 0, 0, 1); + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, + 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({1, 3, 2});