diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index 5183ab4b062..21fdceac6d7 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -360,6 +360,7 @@ SelectOpTest/.+,29 -SliceOpTest/SliceOpTest/SliceString/.+ -SliceOpTest/SliceOpTest/SliceInt64/.+ -SliceOpTest/SliceOpTest/SliceBool/.+ +-SliceOpTest/SliceOpTest/SliceInt16/.+ # Only constant tensors SliceOpTest/SliceOpTest/.+/0,29 @@ -409,6 +410,7 @@ TopKV2OpTest/TopKV2OpTest/.+/0,29 -TransposeTest/5DDividedIntoTwo2Ds.* -TransposeTest/Complex5DTest.* -TransposeTest/.+DynamicTensor +-TransposeTest/TestRefOps4DInt16 TransposeTest/.+ # transpose_conv_test diff --git a/tensorflow/lite/kernels/slice.cc b/tensorflow/lite/kernels/slice.cc index bb123302995..3f6eb73e843 100644 --- a/tensorflow/lite/kernels/slice.cc +++ b/tensorflow/lite/kernels/slice.cc @@ -214,6 +214,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt8: TF_LITE_SLICE(int8_t, kernel_type); break; + case kTfLiteInt16: + TF_LITE_SLICE(int16_t, kernel_type); + break; case kTfLiteUInt8: TF_LITE_SLICE(uint8_t, kernel_type); break; diff --git a/tensorflow/lite/kernels/slice_test.cc b/tensorflow/lite/kernels/slice_test.cc index 1e61e1e68aa..0379eda9a01 100644 --- a/tensorflow/lite/kernels/slice_test.cc +++ b/tensorflow/lite/kernels/slice_test.cc @@ -226,6 +226,16 @@ TEST_P(SliceOpTest, SliceInt8) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); } +TEST_P(SliceOpTest, SliceInt16) { + SliceOpModel m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4}, + {2, 1, -1, 1}, TensorType_INT32, + TensorType_INT16, GetParam()); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); +} + TEST_P(SliceOpTest, SliceString) { SliceOpModel m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4}, {2, 1, -1, 1}, TensorType_INT32, diff --git a/tensorflow/lite/kernels/transpose.cc b/tensorflow/lite/kernels/transpose.cc index 3a6d1b1f1ed..f5ddcb2b362 100644 --- a/tensorflow/lite/kernels/transpose.cc +++ b/tensorflow/lite/kernels/transpose.cc @@ -130,6 +130,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_TRANSPOSE(reference_ops, int8_t); } break; + case kTfLiteInt16: + TF_LITE_TRANSPOSE(reference_ops, int16_t); + break; case kTfLiteInt64: TF_LITE_TRANSPOSE(reference_ops, int64_t); break; diff --git a/tensorflow/lite/kernels/transpose_test.cc b/tensorflow/lite/kernels/transpose_test.cc index a88abec7161..449afe8dec7 100644 --- a/tensorflow/lite/kernels/transpose_test.cc +++ b/tensorflow/lite/kernels/transpose_test.cc @@ -180,13 +180,14 @@ TEST(TransposeTest, TestRefOps4D) { ASSERT_EQ(out, ref); } -TEST(TransposeTest, TestRefOps4DInt8) { - std::vector out; +template +void TransposeTestTestRefOps4D() { + std::vector out; // Basic 4d. RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out); ASSERT_EQ( out, - std::vector( + std::vector( {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44, 60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104, 5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49, @@ -197,11 +198,19 @@ TEST(TransposeTest, TestRefOps4DInt8) { 75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119})); RunTestPermutation({2, 3, 4, 5}, {0, 1, 2, 3}, &out); // Basic identity. - std::vector ref(out.size()); + std::vector ref(out.size()); for (int k = 0; k < ref.size(); k++) ref[k] = k; ASSERT_EQ(out, ref); } +TEST(TransposeTest, TestRefOps4DInt8) { + TransposeTestTestRefOps4D(); +} + +TEST(TransposeTest, TestRefOps4DInt16) { + TransposeTestTestRefOps4D(); +} + class TransposeOpModel : public SingleOpModel { public: void SetInput(std::initializer_list data) {