Fix overflow bug in strided_slice kernel implementation.

strided_slice would overflow for end and start slices larger than int16. This changes the StridedSliceParams start_indices and end_indices to int32 values. This lines up better with the documented types of begin/end.

PiperOrigin-RevId: 282423014
Change-Id: I444a2f7fd094aeafc5a8126ec3785a4ed928f92f
This commit is contained in:
A. Unique TensorFlower 2019-11-25 13:36:03 -08:00 committed by TensorFlower Gardener
parent 805e659f8e
commit 7960506932
2 changed files with 23 additions and 5 deletions

View File

@ -1040,11 +1040,11 @@ struct SqueezeParams {
struct StridedSliceParams {
int8 start_indices_count;
int16 start_indices[4];
int32 start_indices[4];
int8 stop_indices_count;
int16 stop_indices[4];
int32 stop_indices[4];
int8 strides_count;
int16 strides[4];
int32 strides[4];
int16 begin_mask;
int16 ellipsis_mask;

View File

@ -49,6 +49,9 @@ class StridedSliceOpModel : public SingleOpModel {
void SetInput(std::initializer_list<input_type> data) {
PopulateTensor<input_type>(input_, data);
}
void SetInput(const std::vector<input_type> data) {
PopulateTensor<input_type>(input_, data);
}
void SetBegin(std::initializer_list<int32_t> data) {
PopulateTensor<int32_t>(begin_, data);
}
@ -98,6 +101,21 @@ TEST(StridedSliceOpTest, In1D) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
}
TEST(StridedSliceOpTest, In1D_Int32End) {
StridedSliceOpModel<> m({32768}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
std::vector<float> values;
for (int i = 0; i < 32768; i++) {
values.push_back(i);
}
m.SetInput(values);
m.SetBegin({0});
m.SetEnd({32768});
m.SetStrides({1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({32768}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray(values));
}
TEST(StridedSliceOpTest, In1D_EmptyOutput) {
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
@ -567,8 +585,8 @@ TEST(StridedSliceOpTest, RunTwice) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
0, 0, 1);
StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0,
0, 0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({1, 3, 2});