Fix overflow bug in strided_slice kernel implementation.
strided_slice would overflow for end and start slices larger than int16. This changes the StridedSliceParams start_indices and end_indices to int32 values. This lines up better with the documented types of begin/end. PiperOrigin-RevId: 282423014 Change-Id: I444a2f7fd094aeafc5a8126ec3785a4ed928f92f
This commit is contained in:
parent
805e659f8e
commit
7960506932
@ -1040,11 +1040,11 @@ struct SqueezeParams {
|
|||||||
|
|
||||||
struct StridedSliceParams {
|
struct StridedSliceParams {
|
||||||
int8 start_indices_count;
|
int8 start_indices_count;
|
||||||
int16 start_indices[4];
|
int32 start_indices[4];
|
||||||
int8 stop_indices_count;
|
int8 stop_indices_count;
|
||||||
int16 stop_indices[4];
|
int32 stop_indices[4];
|
||||||
int8 strides_count;
|
int8 strides_count;
|
||||||
int16 strides[4];
|
int32 strides[4];
|
||||||
|
|
||||||
int16 begin_mask;
|
int16 begin_mask;
|
||||||
int16 ellipsis_mask;
|
int16 ellipsis_mask;
|
||||||
|
@ -49,6 +49,9 @@ class StridedSliceOpModel : public SingleOpModel {
|
|||||||
void SetInput(std::initializer_list<input_type> data) {
|
void SetInput(std::initializer_list<input_type> data) {
|
||||||
PopulateTensor<input_type>(input_, data);
|
PopulateTensor<input_type>(input_, data);
|
||||||
}
|
}
|
||||||
|
void SetInput(const std::vector<input_type> data) {
|
||||||
|
PopulateTensor<input_type>(input_, data);
|
||||||
|
}
|
||||||
void SetBegin(std::initializer_list<int32_t> data) {
|
void SetBegin(std::initializer_list<int32_t> data) {
|
||||||
PopulateTensor<int32_t>(begin_, data);
|
PopulateTensor<int32_t>(begin_, data);
|
||||||
}
|
}
|
||||||
@ -98,6 +101,21 @@ TEST(StridedSliceOpTest, In1D) {
|
|||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(StridedSliceOpTest, In1D_Int32End) {
|
||||||
|
StridedSliceOpModel<> m({32768}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||||
|
std::vector<float> values;
|
||||||
|
for (int i = 0; i < 32768; i++) {
|
||||||
|
values.push_back(i);
|
||||||
|
}
|
||||||
|
m.SetInput(values);
|
||||||
|
m.SetBegin({0});
|
||||||
|
m.SetEnd({32768});
|
||||||
|
m.SetStrides({1});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({32768}));
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray(values));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(StridedSliceOpTest, In1D_EmptyOutput) {
|
TEST(StridedSliceOpTest, In1D_EmptyOutput) {
|
||||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||||
m.SetInput({1, 2, 3, 4});
|
m.SetInput({1, 2, 3, 4});
|
||||||
@ -567,8 +585,8 @@ TEST(StridedSliceOpTest, RunTwice) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
|
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
|
||||||
StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
|
StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0,
|
||||||
0, 0, 1);
|
0, 0, 0, 1);
|
||||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||||
m.SetBegin({0, 0, 0});
|
m.SetBegin({0, 0, 0});
|
||||||
m.SetEnd({1, 3, 2});
|
m.SetEnd({1, 3, 2});
|
||||||
|
Loading…
Reference in New Issue
Block a user