diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc index fb9efcc9db5..e0e226e2ea5 100644 --- a/tensorflow/lite/kernels/strided_slice.cc +++ b/tensorflow/lite/kernels/strided_slice.cc @@ -203,6 +203,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_STRIDED_SLICE(reference_ops, int8_t); } break; + case kTfLiteInt16: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int16_t); + } + break; case kTfLiteBool: if (kernel_type == kReference) { TF_LITE_STRIDED_SLICE(reference_ops, bool); diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index 7464c28793a..83093a09eed 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -23,8 +23,7 @@ namespace { using ::testing::ElementsAreArray; -template +template class StridedSliceOpModel : public SingleOpModel { public: StridedSliceOpModel(std::initializer_list input_shape, @@ -33,11 +32,11 @@ class StridedSliceOpModel : public SingleOpModel { std::initializer_list strides_shape, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, int shrink_axis_mask) { - input_ = AddInput(tensor_input_type); + input_ = AddInput(GetTensorType()); begin_ = AddInput(TensorType_INT32); end_ = AddInput(TensorType_INT32); strides_ = AddInput(TensorType_INT32); - output_ = AddOutput(tensor_input_type); + output_ = AddOutput(GetTensorType()); SetBuiltinOp( BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, @@ -75,23 +74,31 @@ class StridedSliceOpModel : public SingleOpModel { int output_; }; +template +class StridedSliceOpTest : public ::testing::Test {}; + +using DataTypes = ::testing::Types; +TYPED_TEST_SUITE(StridedSliceOpTest, DataTypes); + #ifdef GTEST_HAS_DEATH_TEST -TEST(StridedSliceOpTest, UnsupportedInputSize) { - EXPECT_DEATH( - StridedSliceOpModel<>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0), - "StridedSlice op only supports 1D-4D input arrays."); +TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) { + EXPECT_DEATH(StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, + 0, 0, 0, 0), + "StridedSlice op only supports 1D-4D input arrays."); } -TEST(StridedSliceOpTest, UnssupportedArgs) { - EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), - "ellipsis_mask is not implemented yet."); - EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), - "new_axis_mask is not implemented yet."); +TYPED_TEST(StridedSliceOpTest, UnssupportedArgs) { + EXPECT_DEATH( + StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), + "ellipsis_mask is not implemented yet."); + EXPECT_DEATH( + StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), + "new_axis_mask is not implemented yet."); } #endif -TEST(StridedSliceOpTest, In1D) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -101,9 +108,9 @@ TEST(StridedSliceOpTest, In1D) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); } -TEST(StridedSliceOpTest, In1D_Int32End) { - StridedSliceOpModel<> m({32768}, {1}, {1}, {1}, 0, 0, 0, 0, 0); - std::vector values; +TYPED_TEST(StridedSliceOpTest, In1D_Int32End) { + StridedSliceOpModel m({32768}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + std::vector values; for (int i = 0; i < 32768; i++) { values.push_back(i); } @@ -116,8 +123,8 @@ TEST(StridedSliceOpTest, In1D_Int32End) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(values)); } -TEST(StridedSliceOpTest, In1D_EmptyOutput) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_EmptyOutput) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({10}); m.SetEnd({3}); @@ -126,8 +133,8 @@ TEST(StridedSliceOpTest, In1D_EmptyOutput) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0})); } -TEST(StridedSliceOpTest, In1D_NegativeBegin) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_NegativeBegin) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({3}); @@ -137,8 +144,8 @@ TEST(StridedSliceOpTest, In1D_NegativeBegin) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); } -TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-5}); m.SetEnd({3}); @@ -148,8 +155,8 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); } -TEST(StridedSliceOpTest, In1D_NegativeEnd) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_NegativeEnd) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({-2}); @@ -159,8 +166,8 @@ TEST(StridedSliceOpTest, In1D_NegativeEnd) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); } -TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({5}); @@ -170,8 +177,8 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); } -TEST(StridedSliceOpTest, In1D_BeginMask) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_BeginMask) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -181,8 +188,8 @@ TEST(StridedSliceOpTest, In1D_BeginMask) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); } -TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-2}); m.SetEnd({-3}); @@ -192,8 +199,8 @@ TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); } -TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({5}); m.SetEnd({2}); @@ -203,8 +210,8 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({4})); } -TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({2}); m.SetEnd({-4}); @@ -214,8 +221,8 @@ TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2})); } -TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({-3}); m.SetEnd({-5}); @@ -225,8 +232,8 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1})); } -TEST(StridedSliceOpTest, In1D_EndMask) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_EndMask) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({3}); @@ -236,8 +243,8 @@ TEST(StridedSliceOpTest, In1D_EndMask) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); } -TEST(StridedSliceOpTest, In1D_NegStride) { - StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_NegStride) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); m.SetBegin({-1}); m.SetEnd({-4}); @@ -247,8 +254,8 @@ TEST(StridedSliceOpTest, In1D_NegStride) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1})); } -TEST(StridedSliceOpTest, In1D_EvenLenStride2) { - StridedSliceOpModel<> m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_EvenLenStride2) { + StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2}); m.SetBegin({0}); m.SetEnd({2}); @@ -258,8 +265,8 @@ TEST(StridedSliceOpTest, In1D_EvenLenStride2) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } -TEST(StridedSliceOpTest, In1D_OddLenStride2) { - StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In1D_OddLenStride2) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); m.SetBegin({0}); m.SetEnd({3}); @@ -269,8 +276,8 @@ TEST(StridedSliceOpTest, In1D_OddLenStride2) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); } -TEST(StridedSliceOpTest, In2D_Identity) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In2D_Identity) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -280,8 +287,8 @@ TEST(StridedSliceOpTest, In2D_Identity) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(StridedSliceOpTest, In2D) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In2D) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -291,8 +298,8 @@ TEST(StridedSliceOpTest, In2D) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); } -TEST(StridedSliceOpTest, In2D_Stride2) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In2D_Stride2) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 3}); @@ -302,8 +309,8 @@ TEST(StridedSliceOpTest, In2D_Stride2) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); } -TEST(StridedSliceOpTest, In2D_NegStride) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In2D_NegStride) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -1}); m.SetEnd({2, -4}); @@ -313,8 +320,8 @@ TEST(StridedSliceOpTest, In2D_NegStride) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); } -TEST(StridedSliceOpTest, In2D_BeginMask) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In2D_BeginMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -324,8 +331,8 @@ TEST(StridedSliceOpTest, In2D_BeginMask) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); } -TEST(StridedSliceOpTest, In2D_EndMask) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In2D_EndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, 0}); m.SetEnd({2, 2}); @@ -335,8 +342,8 @@ TEST(StridedSliceOpTest, In2D_EndMask) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5, 6})); } -TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -2}); m.SetEnd({2, -4}); @@ -346,8 +353,8 @@ TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); } -TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({1, -2}); m.SetEnd({2, -3}); @@ -357,8 +364,8 @@ TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 4})); } -TEST(StridedSliceOpTest, In3D_Identity) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In3D_Identity) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -369,8 +376,8 @@ TEST(StridedSliceOpTest, In3D_Identity) { ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); } -TEST(StridedSliceOpTest, In3D_NegStride) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In3D_NegStride) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({-1, -1, -1}); m.SetEnd({-3, -4, -3}); @@ -381,8 +388,8 @@ TEST(StridedSliceOpTest, In3D_NegStride) { ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})); } -TEST(StridedSliceOpTest, In3D_Strided2) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, In3D_Strided2) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 2}); @@ -392,8 +399,8 @@ TEST(StridedSliceOpTest, In3D_Strided2) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5})); } -TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); +TYPED_TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({2}); @@ -403,9 +410,9 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); } -TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) { +TYPED_TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) { // This is equivalent to tf.range(4)[-1]. - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); m.SetInput({0, 1, 2, 3}); m.SetBegin({-1}); m.SetEnd({0}); @@ -416,9 +423,9 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); } -TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) { +TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) { // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1]. - StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + StridedSliceOpModel m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3); m.SetInput({0, 1, 2, 3}); m.SetBegin({-2, -1}); m.SetEnd({-1, 0}); @@ -429,9 +436,9 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); } -TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) { +TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) { // This is equivalent to tf.range(4)[:, tf.newaxis][:, -1]. - StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2); + StridedSliceOpModel m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2); m.SetInput({0, 1, 2, 3}); m.SetBegin({0, -1}); m.SetEnd({0, 0}); @@ -442,8 +449,8 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3})); } -TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); +TYPED_TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); m.SetBegin({1}); m.SetEnd({1}); @@ -453,8 +460,8 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } -TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); +TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({1, 3}); @@ -464,8 +471,8 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); } -TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); +TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({2, 1}); @@ -475,8 +482,8 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4})); } -TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); +TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); m.SetInput({1, 2, 3, 4, 5, 6}); m.SetBegin({0, 0}); m.SetEnd({1, 1}); @@ -486,8 +493,8 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({1, 3, 2}); @@ -497,8 +504,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 1, 2}); @@ -508,8 +515,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({1, 1, 2}); @@ -519,8 +526,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 3, 1}); @@ -530,8 +537,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({1, 3, 1}); @@ -541,8 +548,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({2, 1, 1}); @@ -552,8 +559,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 7})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { - StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({1, 1, 1}); @@ -564,8 +571,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { } // This tests catches a very subtle bug that was fixed by cl/188403234. -TEST(StridedSliceOpTest, RunTwice) { - StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); +TYPED_TEST(StridedSliceOpTest, RunTwice) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); auto setup_inputs = [&m]() { m.SetInput({1, 2, 3, 4, 5, 6}); @@ -584,9 +591,8 @@ TEST(StridedSliceOpTest, RunTwice) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, - 0, 0, 0, 1); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({1, 3, 2}); @@ -596,9 +602,8 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, - 0, 0, 1); +TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); m.SetEnd({1, 3, 2});