Merge pull request #36070 from psunn:int16_strided_slice
PiperOrigin-RevId: 292252738 Change-Id: Id6ab2e3dc9928fd643a1c852af5cd8ce2a17e699
This commit is contained in:
commit
fa70bf027f
@ -203,6 +203,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, int16_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, bool);
|
||||
|
@ -23,8 +23,7 @@ namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
template <typename input_type = float,
|
||||
TensorType tensor_input_type = TensorType_FLOAT32>
|
||||
template <typename input_type>
|
||||
class StridedSliceOpModel : public SingleOpModel {
|
||||
public:
|
||||
StridedSliceOpModel(std::initializer_list<int> input_shape,
|
||||
@ -33,11 +32,11 @@ class StridedSliceOpModel : public SingleOpModel {
|
||||
std::initializer_list<int> strides_shape, int begin_mask,
|
||||
int end_mask, int ellipsis_mask, int new_axis_mask,
|
||||
int shrink_axis_mask) {
|
||||
input_ = AddInput(tensor_input_type);
|
||||
input_ = AddInput(GetTensorType<input_type>());
|
||||
begin_ = AddInput(TensorType_INT32);
|
||||
end_ = AddInput(TensorType_INT32);
|
||||
strides_ = AddInput(TensorType_INT32);
|
||||
output_ = AddOutput(tensor_input_type);
|
||||
output_ = AddOutput(GetTensorType<input_type>());
|
||||
SetBuiltinOp(
|
||||
BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
|
||||
CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
|
||||
@ -75,23 +74,31 @@ class StridedSliceOpModel : public SingleOpModel {
|
||||
int output_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class StridedSliceOpTest : public ::testing::Test {};
|
||||
|
||||
using DataTypes = ::testing::Types<float, uint8_t, int8_t, int16_t, int32_t>;
|
||||
TYPED_TEST_SUITE(StridedSliceOpTest, DataTypes);
|
||||
|
||||
#ifdef GTEST_HAS_DEATH_TEST
|
||||
TEST(StridedSliceOpTest, UnsupportedInputSize) {
|
||||
EXPECT_DEATH(
|
||||
StridedSliceOpModel<>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0),
|
||||
"StridedSlice op only supports 1D-4D input arrays.");
|
||||
TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) {
|
||||
EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0,
|
||||
0, 0, 0, 0),
|
||||
"StridedSlice op only supports 1D-4D input arrays.");
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, UnssupportedArgs) {
|
||||
EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0),
|
||||
"ellipsis_mask is not implemented yet.");
|
||||
EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
|
||||
"new_axis_mask is not implemented yet.");
|
||||
TYPED_TEST(StridedSliceOpTest, UnssupportedArgs) {
|
||||
EXPECT_DEATH(
|
||||
StridedSliceOpModel<TypeParam>({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0),
|
||||
"ellipsis_mask is not implemented yet.");
|
||||
EXPECT_DEATH(
|
||||
StridedSliceOpModel<TypeParam>({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
|
||||
"new_axis_mask is not implemented yet.");
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(StridedSliceOpTest, In1D) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({3});
|
||||
@ -101,9 +108,9 @@ TEST(StridedSliceOpTest, In1D) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_Int32End) {
|
||||
StridedSliceOpModel<> m({32768}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
std::vector<float> values;
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_Int32End) {
|
||||
StridedSliceOpModel<TypeParam> m({32768}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
std::vector<TypeParam> values;
|
||||
for (int i = 0; i < 32768; i++) {
|
||||
values.push_back(i);
|
||||
}
|
||||
@ -116,8 +123,8 @@ TEST(StridedSliceOpTest, In1D_Int32End) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(values));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_EmptyOutput) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_EmptyOutput) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({10});
|
||||
m.SetEnd({3});
|
||||
@ -126,8 +133,8 @@ TEST(StridedSliceOpTest, In1D_EmptyOutput) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_NegativeBegin) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_NegativeBegin) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({-3});
|
||||
m.SetEnd({3});
|
||||
@ -137,8 +144,8 @@ TEST(StridedSliceOpTest, In1D_NegativeBegin) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({-5});
|
||||
m.SetEnd({3});
|
||||
@ -148,8 +155,8 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_NegativeEnd) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_NegativeEnd) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({-2});
|
||||
@ -159,8 +166,8 @@ TEST(StridedSliceOpTest, In1D_NegativeEnd) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({-3});
|
||||
m.SetEnd({5});
|
||||
@ -170,8 +177,8 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_BeginMask) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_BeginMask) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({3});
|
||||
@ -181,8 +188,8 @@ TEST(StridedSliceOpTest, In1D_BeginMask) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({-2});
|
||||
m.SetEnd({-3});
|
||||
@ -192,8 +199,8 @@ TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({5});
|
||||
m.SetEnd({2});
|
||||
@ -203,8 +210,8 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({4}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({2});
|
||||
m.SetEnd({-4});
|
||||
@ -214,8 +221,8 @@ TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({-3});
|
||||
m.SetEnd({-5});
|
||||
@ -225,8 +232,8 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_EndMask) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_EndMask) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({3});
|
||||
@ -236,8 +243,8 @@ TEST(StridedSliceOpTest, In1D_EndMask) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_NegStride) {
|
||||
StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_NegStride) {
|
||||
StridedSliceOpModel<TypeParam> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3});
|
||||
m.SetBegin({-1});
|
||||
m.SetEnd({-4});
|
||||
@ -247,8 +254,8 @@ TEST(StridedSliceOpTest, In1D_NegStride) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
|
||||
StridedSliceOpModel<> m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
|
||||
StridedSliceOpModel<TypeParam> m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2});
|
||||
m.SetBegin({0});
|
||||
m.SetEnd({2});
|
||||
@ -258,8 +265,8 @@ TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_OddLenStride2) {
|
||||
StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_OddLenStride2) {
|
||||
StridedSliceOpModel<TypeParam> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3});
|
||||
m.SetBegin({0});
|
||||
m.SetEnd({3});
|
||||
@ -269,8 +276,8 @@ TEST(StridedSliceOpTest, In1D_OddLenStride2) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_Identity) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_Identity) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({0, 0});
|
||||
m.SetEnd({2, 3});
|
||||
@ -280,8 +287,8 @@ TEST(StridedSliceOpTest, In2D_Identity) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({1, 0});
|
||||
m.SetEnd({2, 2});
|
||||
@ -291,8 +298,8 @@ TEST(StridedSliceOpTest, In2D) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_Stride2) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_Stride2) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({0, 0});
|
||||
m.SetEnd({2, 3});
|
||||
@ -302,8 +309,8 @@ TEST(StridedSliceOpTest, In2D_Stride2) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_NegStride) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_NegStride) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({1, -1});
|
||||
m.SetEnd({2, -4});
|
||||
@ -313,8 +320,8 @@ TEST(StridedSliceOpTest, In2D_NegStride) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_BeginMask) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_BeginMask) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({1, 0});
|
||||
m.SetEnd({2, 2});
|
||||
@ -324,8 +331,8 @@ TEST(StridedSliceOpTest, In2D_BeginMask) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_EndMask) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_EndMask) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({1, 0});
|
||||
m.SetEnd({2, 2});
|
||||
@ -335,8 +342,8 @@ TEST(StridedSliceOpTest, In2D_EndMask) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({1, -2});
|
||||
m.SetEnd({2, -4});
|
||||
@ -346,8 +353,8 @@ TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({1, -2});
|
||||
m.SetEnd({2, -3});
|
||||
@ -357,8 +364,8 @@ TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 4}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_Identity) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_Identity) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({2, 3, 2});
|
||||
@ -369,8 +376,8 @@ TEST(StridedSliceOpTest, In3D_Identity) {
|
||||
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_NegStride) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_NegStride) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({-1, -1, -1});
|
||||
m.SetEnd({-3, -4, -3});
|
||||
@ -381,8 +388,8 @@ TEST(StridedSliceOpTest, In3D_NegStride) {
|
||||
ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_Strided2) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_Strided2) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({2, 3, 2});
|
||||
@ -392,8 +399,8 @@ TEST(StridedSliceOpTest, In3D_Strided2) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({2});
|
||||
@ -403,9 +410,9 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) {
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) {
|
||||
// This is equivalent to tf.range(4)[-1].
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
|
||||
m.SetInput({0, 1, 2, 3});
|
||||
m.SetBegin({-1});
|
||||
m.SetEnd({0});
|
||||
@ -416,9 +423,9 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) {
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) {
|
||||
// This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1].
|
||||
StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
|
||||
StridedSliceOpModel<TypeParam> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
|
||||
m.SetInput({0, 1, 2, 3});
|
||||
m.SetBegin({-2, -1});
|
||||
m.SetEnd({-1, 0});
|
||||
@ -429,9 +436,9 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) {
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) {
|
||||
// This is equivalent to tf.range(4)[:, tf.newaxis][:, -1].
|
||||
StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2);
|
||||
StridedSliceOpModel<TypeParam> m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2);
|
||||
m.SetInput({0, 1, 2, 3});
|
||||
m.SetBegin({0, -1});
|
||||
m.SetEnd({0, 0});
|
||||
@ -442,8 +449,8 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
|
||||
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
|
||||
TYPED_TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
|
||||
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({1});
|
||||
@ -453,8 +460,8 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({0, 0});
|
||||
m.SetEnd({1, 3});
|
||||
@ -464,8 +471,8 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({0, 0});
|
||||
m.SetEnd({2, 1});
|
||||
@ -475,8 +482,8 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
|
||||
TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({0, 0});
|
||||
m.SetEnd({1, 1});
|
||||
@ -486,8 +493,8 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({1, 3, 2});
|
||||
@ -497,8 +504,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({2, 1, 2});
|
||||
@ -508,8 +515,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({1, 1, 2});
|
||||
@ -519,8 +526,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({2, 3, 1});
|
||||
@ -530,8 +537,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({1, 3, 1});
|
||||
@ -541,8 +548,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({2, 1, 1});
|
||||
@ -552,8 +559,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 7}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) {
|
||||
StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({1, 1, 1});
|
||||
@ -564,8 +571,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) {
|
||||
}
|
||||
|
||||
// This tests catches a very subtle bug that was fixed by cl/188403234.
|
||||
TEST(StridedSliceOpTest, RunTwice) {
|
||||
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
|
||||
TYPED_TEST(StridedSliceOpTest, RunTwice) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
|
||||
|
||||
auto setup_inputs = [&m]() {
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
@ -584,9 +591,8 @@ TEST(StridedSliceOpTest, RunTwice) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
|
||||
StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0,
|
||||
0, 0, 0, 1);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({1, 3, 2});
|
||||
@ -596,9 +602,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) {
|
||||
StridedSliceOpModel<int8_t, TensorType_INT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
|
||||
0, 0, 1);
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({1, 3, 2});
|
||||
|
Loading…
Reference in New Issue
Block a user