diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index 56c1895ca4e..3867c95331a 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -331,6 +331,8 @@ ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+AlignCorners.*/0,30 ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+HalfPixelCenters.*/0,30 // Only models with constant size tensor are accelerated ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest/.+/0,29 +// 16-bit tests are not supported +-ResizeNearestNeighborOpTest/.*Int16.* # select_test -SelectOpTest/SelectBool diff --git a/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc b/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc index f8a455e7451..debeb36e48f 100644 --- a/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc @@ -91,6 +91,17 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3) { output_shape, output_data); } +TEST(ResizeNearestNeighborReference, Test2x2To3x3Int16) { + RuntimeShape input_shape = {1, 2, 2, 1}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {3, 3}; + RuntimeShape output_shape = {1, 3, 3, 1}; + std::vector output_data = {1, 1, 2, 1, 1, 2, 3, 3, 4}; + + TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, + output_shape, output_data); +} + TEST(ResizeNearestNeighborReference, Test2x2To3x3_AlignCorners) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 333ffc12d7e..391ceca83b2 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -123,7 +123,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, Register_RESIZE_NEAREST_NEIGHBOR(), /* min_version = */ 1, - /* max_version = */ 3); + /* max_version = */ 4); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(), /* min_version = */ 1, diff --git a/tensorflow/lite/kernels/resize_nearest_neighbor.cc b/tensorflow/lite/kernels/resize_nearest_neighbor.cc index 13c54c4f906..bc4eb85afeb 100644 --- a/tensorflow/lite/kernels/resize_nearest_neighbor.cc +++ b/tensorflow/lite/kernels/resize_nearest_neighbor.cc @@ -121,6 +121,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(size), GetTensorData(size), GetTensorShape(output), GetTensorData(output)); + } else if (output->type == kTfLiteInt16) { + reference_ops::ResizeNearestNeighbor( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(size), GetTensorData(size), + GetTensorShape(output), GetTensorData(output)); } else { TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float, uint8 or int8.", diff --git a/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc b/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc index b22ad48afb9..5ceb1b6ea83 100644 --- a/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc @@ -106,6 +106,14 @@ TEST_P(ResizeNearestNeighborOpTest, HorizontalResizeInt8) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-3, -3, 6}))); } +TEST_P(ResizeNearestNeighborOpTest, HorizontalResizeInt16) { + ResizeNearestNeighborOpModel m({TensorType_INT16, {1, 1, 2, 1}}, {1, 3}, + GetParam()); + m.SetInput({-3, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-3, -3, 6}))); +} TEST_P(ResizeNearestNeighborOpTest, VerticalResize) { ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}, GetParam()); @@ -130,6 +138,14 @@ TEST_P(ResizeNearestNeighborOpTest, VerticalResizeInt8) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 3, -9}))); } +TEST_P(ResizeNearestNeighborOpTest, VerticalResizeInt16) { + ResizeNearestNeighborOpModel m({TensorType_INT16, {1, 2, 1, 1}}, {3, 1}, + GetParam()); + m.SetInput({3, -9}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 3, -9}))); +} TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResize) { ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}, GetParam()); @@ -172,6 +188,20 @@ TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt8) { 9, 9, 12, // }))); } +TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt16) { + ResizeNearestNeighborOpModel m({TensorType_INT16, {1, 2, 2, 1}}, {3, 3}, + GetParam()); + m.SetInput({ + 3, -6, // + 9, 12 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 3, -6, // + 3, 3, -6, // + 9, 9, 12, // + }))); +} TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) { ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}, GetParam()); @@ -284,6 +314,25 @@ TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) { 12, 12, 16, // }))); } +TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt16) { + ResizeNearestNeighborOpModel m({TensorType_INT16, {2, 2, 2, 1}}, {3, 3}, + GetParam()); + m.SetInput({ + 3, 6, // + 9, -12, // + -4, 10, // + 12, 16 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 3, 6, // + 3, 3, 6, // + 9, 9, -12, // + -4, -4, 10, // + -4, -4, 10, // + 12, 12, 16, // + }))); +} TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) { ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}, GetParam()); @@ -342,6 +391,20 @@ TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) { 10, 12, 10, 12, -14, 16, // }))); } +TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt16) { + ResizeNearestNeighborOpModel m({TensorType_INT16, {1, 2, 2, 2}}, {3, 3}, + GetParam()); + m.SetInput({ + 3, 4, -6, 10, // + 10, 12, -14, 16, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 3, 4, -6, 10, // + 3, 4, 3, 4, -6, 10, // + 10, 12, 10, 12, -14, 16, // + }))); +} INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTest, ResizeNearestNeighborOpTest, testing::Values(TestType::kConst, TestType::kDynamic)); diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index 02afc35de3b..c6430bf9974 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -145,6 +145,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kResizeNearestNeighbor, 1}, "1.13.1"}, {{OperatorType::kResizeNearestNeighbor, 2}, "1.14.0"}, {{OperatorType::kResizeNearestNeighbor, 3}, kPendingReleaseOpVersion}, + {{OperatorType::kResizeNearestNeighbor, 4}, kPendingReleaseOpVersion}, {{OperatorType::kSqueeze, 1}, "1.6.0"}, {{OperatorType::kSplit, 1}, "1.5.0"}, {{OperatorType::kSplit, 2}, "1.14.0"}, diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index e105c0f2d64..4567527a8c6 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -868,13 +868,18 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.version = 1; break; case BuiltinOperator_RESIZE_BILINEAR: - case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: property.inputs = {{0, {}}}; property.outputs = {{0, {}}}; property.restrict_same_input_output_scale = true; property.version = 2; property.quantizable_int16 = false; break; + case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: + property.inputs = {{0, {}}}; + property.outputs = {{0, {}}}; + property.restrict_same_input_output_scale = true; + property.version = 2; + break; case BuiltinOperator_SHAPE: property.inputs = {{0, {}}}; // Shape has no quantizable output. diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 2f62230f334..de5f33e8698 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -396,8 +396,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: - if (op_sig.options.resize.half_pixel_centers || - op_sig.options.resize.align_corners) { + if (op_sig.input_types.at(0) == TensorType_INT16) { + return 4; + } else if (op_sig.options.resize.half_pixel_centers || + op_sig.options.resize.align_corners) { return 3; } else if (op_sig.input_types.at(0) == TensorType_INT8) { return 2; diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index 2f13b7234e3..3f9aca06ec2 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -697,5 +697,13 @@ TEST(OpVersionTest, VersioningResizeNearestNeighborTest) { fake_op_sig.options.resize.align_corners = true; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + // int16 input is version 4. + fake_op_sig = { + .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, + .input_types = std::vector{TensorType_INT16, TensorType_INT32}, + .output_types = std::vector{TensorType_INT16}, + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); } } // namespace tflite diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index d345164f7e6..a9d3917cf3d 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -186,6 +186,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 1}, "1.13.1"}, {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 2}, "1.14.0"}, {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 3}, "2.3.0"}, + {{BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 4}, kPendingReleaseVersion}, {{BuiltinOperator_RNN, 1}, "1.5.0"}, {{BuiltinOperator_RNN, 2}, "1.14.0"}, {{BuiltinOperator_RNN, 3}, "2.3.0"},