Fix overflow bug in strided_slice kernel implementation.
strided_slice would overflow for end and start slices larger than int16. This changes the StridedSliceParams start_indices and end_indices to int32 values. This lines up better with the documented types of begin/end. PiperOrigin-RevId: 282423014 Change-Id: I444a2f7fd094aeafc5a8126ec3785a4ed928f92f
This commit is contained in:
parent
805e659f8e
commit
7960506932
@ -1040,11 +1040,11 @@ struct SqueezeParams {
|
||||
|
||||
struct StridedSliceParams {
|
||||
int8 start_indices_count;
|
||||
int16 start_indices[4];
|
||||
int32 start_indices[4];
|
||||
int8 stop_indices_count;
|
||||
int16 stop_indices[4];
|
||||
int32 stop_indices[4];
|
||||
int8 strides_count;
|
||||
int16 strides[4];
|
||||
int32 strides[4];
|
||||
|
||||
int16 begin_mask;
|
||||
int16 ellipsis_mask;
|
||||
|
@ -49,6 +49,9 @@ class StridedSliceOpModel : public SingleOpModel {
|
||||
void SetInput(std::initializer_list<input_type> data) {
|
||||
PopulateTensor<input_type>(input_, data);
|
||||
}
|
||||
void SetInput(const std::vector<input_type> data) {
|
||||
PopulateTensor<input_type>(input_, data);
|
||||
}
|
||||
void SetBegin(std::initializer_list<int32_t> data) {
|
||||
PopulateTensor<int32_t>(begin_, data);
|
||||
}
|
||||
@ -98,6 +101,21 @@ TEST(StridedSliceOpTest, In1D) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_Int32End) {
|
||||
StridedSliceOpModel<> m({32768}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
std::vector<float> values;
|
||||
for (int i = 0; i < 32768; i++) {
|
||||
values.push_back(i);
|
||||
}
|
||||
m.SetInput(values);
|
||||
m.SetBegin({0});
|
||||
m.SetEnd({32768});
|
||||
m.SetStrides({1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({32768}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(values));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_EmptyOutput) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
@ -567,8 +585,8 @@ TEST(StridedSliceOpTest, RunTwice) {
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
|
||||
StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
|
||||
0, 0, 1);
|
||||
StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0,
|
||||
0, 0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({1, 3, 2});
|
||||
|
Loading…
Reference in New Issue
Block a user